Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 52 additions & 3 deletions flang/lib/Lower/OpenMP/ClauseProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -916,6 +916,30 @@ bool ClauseProcessor::processIsDevicePtr(
});
}

/// Handles the `in_reduction` clauses of the construct being lowered.
///
/// For each matching clause the reduction declaration is created through
/// ReductionProcessor::addDeclareReduction, and the values it produces are
/// appended to the caller-provided outputs.
///
/// \param currentLocation  location forwarded to addDeclareReduction.
/// \param result           receives the reduction variables, by-ref flags and
///                         declaration symbols (appended, never cleared).
/// \param outReductionSyms receives the semantic symbols of the reduction
///                         list items (appended, never cleared).
/// \return whatever findRepeatableClause returns (presumably whether any
///         `in_reduction` clause was present -- confirm against its
///         definition).
bool ClauseProcessor::processInReduction(
    mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result,
    llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const {
  // Callback run by findRepeatableClause for an `in_reduction` clause.
  auto lowerClause = [&](const omp::clause::InReduction &clause,
                         const parser::CharBlock &) {
    // Per-clause scratch lists; filled by addDeclareReduction below.
    llvm::SmallVector<mlir::Value> vars;
    llvm::SmallVector<bool> byRefFlags;
    llvm::SmallVector<mlir::Attribute> declSymbols;
    llvm::SmallVector<const semantics::Symbol *> syms;

    ReductionProcessor rp;
    rp.addDeclareReduction<omp::clause::InReduction>(
        currentLocation, converter, clause, vars, byRefFlags, declSymbols,
        syms);

    // Append this clause's results to the outputs.
    llvm::copy(vars, std::back_inserter(result.inReductionVars));
    llvm::copy(byRefFlags, std::back_inserter(result.inReductionByref));
    llvm::copy(declSymbols, std::back_inserter(result.inReductionSyms));
    llvm::copy(syms, std::back_inserter(outReductionSyms));
  };

  return findRepeatableClause<omp::clause::InReduction>(lowerClause);
}

bool ClauseProcessor::processLink(
llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
return findRepeatableClause<omp::clause::Link>(
Expand Down Expand Up @@ -1126,9 +1150,10 @@ bool ClauseProcessor::processReduction(
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
llvm::SmallVector<const semantics::Symbol *> reductionSyms;
ReductionProcessor rp;
rp.addDeclareReduction(currentLocation, converter, clause,
reductionVars, reduceVarByRef,
reductionDeclSymbols, reductionSyms);

rp.addDeclareReduction<omp::clause::Reduction>(
currentLocation, converter, clause, reductionVars, reduceVarByRef,
reductionDeclSymbols, reductionSyms);

// Copy local lists into the output.
llvm::copy(reductionVars, std::back_inserter(result.reductionVars));
Expand All @@ -1139,6 +1164,30 @@ bool ClauseProcessor::processReduction(
});
}

/// Handles the `task_reduction` clauses of the construct being lowered.
///
/// For each matching clause the reduction declaration is created through
/// ReductionProcessor::addDeclareReduction, and the values it produces are
/// appended to the caller-provided outputs.
///
/// \param currentLocation  location forwarded to addDeclareReduction.
/// \param result           receives the reduction variables, by-ref flags and
///                         declaration symbols (appended, never cleared).
/// \param outReductionSyms receives the semantic symbols of the reduction
///                         list items (appended, never cleared).
/// \return whatever findRepeatableClause returns (presumably whether any
///         `task_reduction` clause was present -- confirm against its
///         definition).
bool ClauseProcessor::processTaskReduction(
    mlir::Location currentLocation, mlir::omp::TaskReductionClauseOps &result,
    llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const {
  return findRepeatableClause<omp::clause::TaskReduction>(
      [&](const omp::clause::TaskReduction &clause, const parser::CharBlock &) {
        // Per-clause scratch lists; filled by addDeclareReduction below.
        // Local names use lowerCamelCase like the sibling process* functions.
        llvm::SmallVector<mlir::Value> taskReductionVars;
        llvm::SmallVector<bool> taskReduceVarByRef;
        llvm::SmallVector<mlir::Attribute> taskReductionDeclSymbols;
        llvm::SmallVector<const semantics::Symbol *> taskReductionSyms;
        ReductionProcessor rp;
        rp.addDeclareReduction<omp::clause::TaskReduction>(
            currentLocation, converter, clause, taskReductionVars,
            taskReduceVarByRef, taskReductionDeclSymbols, taskReductionSyms);

        // Copy local lists into the output.
        llvm::copy(taskReductionVars,
                   std::back_inserter(result.taskReductionVars));
        llvm::copy(taskReduceVarByRef,
                   std::back_inserter(result.taskReductionByref));
        llvm::copy(taskReductionDeclSymbols,
                   std::back_inserter(result.taskReductionSyms));
        llvm::copy(taskReductionSyms, std::back_inserter(outReductionSyms));
      });
}

bool ClauseProcessor::processTo(
llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const {
return findRepeatableClause<omp::clause::To>(
Expand Down
6 changes: 6 additions & 0 deletions flang/lib/Lower/OpenMP/ClauseProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ class ClauseProcessor {
bool processIsDevicePtr(
mlir::omp::IsDevicePtrClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
/// Processes `in_reduction` clauses. Appends the reduction variables, by-ref
/// flags and reduction declaration symbols to \p result, and the semantic
/// symbols of the reduction list items to \p outReductionSyms.
/// \p currentLocation is forwarded to the reduction declaration builder.
bool processInReduction(
mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const;
bool
processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;

Expand All @@ -123,6 +126,9 @@ class ClauseProcessor {
bool processReduction(
mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) const;
/// Processes `task_reduction` clauses. Appends the reduction variables, by-ref
/// flags and reduction declaration symbols to \p result, and the semantic
/// symbols of the reduction list items to \p outReductionSyms.
/// \p currentLocation is forwarded to the reduction declaration builder.
bool processTaskReduction(
mlir::Location currentLocation, mlir::omp::TaskReductionClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const;
bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
bool processUseDeviceAddr(
lower::StatementContext &stmtCtx,
Expand Down
9 changes: 8 additions & 1 deletion flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -344,8 +344,15 @@ void DataSharingProcessor::collectSymbols(
// Collect all symbols referenced in the evaluation being processed,
// that matches 'flag'.
llvm::SetVector<const semantics::Symbol *> allSymbols;

auto itr = llvm::find_if(clauses, [](const omp::Clause &clause) {
return clause.id == llvm::omp::Clause::OMPC_in_reduction;
});

bool collectSymbols = (itr == clauses.end());

converter.collectSymbolSet(eval, allSymbols, flag,
/*collectSymbols=*/true,
/*collectSymbols=*/collectSymbols,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: The comment is redundant when passing a variable instead of a constant.

Suggested change
/*collectSymbols=*/collectSymbols,
collectSymbols,

/*collectHostAssociatedSymbols=*/true);

llvm::SetVector<const semantics::Symbol *> symbolsInNestedRegions;
Expand Down
66 changes: 44 additions & 22 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1249,34 +1249,34 @@ static void genTargetEnterExitUpdateDataClauses(
cp.processNowait(clauseOps);
}

static void genTaskClauses(lower::AbstractConverter &converter,
semantics::SemanticsContext &semaCtx,
lower::StatementContext &stmtCtx,
const List<Clause> &clauses, mlir::Location loc,
mlir::omp::TaskOperands &clauseOps) {
static void genTaskClauses(
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
lower::StatementContext &stmtCtx, const List<Clause> &clauses,
mlir::Location loc, mlir::omp::TaskOperands &clauseOps,
llvm::SmallVectorImpl<const semantics::Symbol *> &inReductionSyms) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processAllocate(clauseOps);
cp.processDepend(clauseOps);
cp.processFinal(stmtCtx, clauseOps);
cp.processIf(llvm::omp::Directive::OMPD_task, clauseOps);
cp.processInReduction(loc, clauseOps, inReductionSyms);
cp.processMergeable(clauseOps);
cp.processPriority(stmtCtx, clauseOps);
cp.processUntied(clauseOps);
cp.processDetach(clauseOps);
// TODO Support delayed privatization.

cp.processTODO<clause::Affinity, clause::InReduction>(
loc, llvm::omp::Directive::OMPD_task);
cp.processTODO<clause::Affinity>(loc, llvm::omp::Directive::OMPD_task);
}

static void genTaskgroupClauses(lower::AbstractConverter &converter,
semantics::SemanticsContext &semaCtx,
const List<Clause> &clauses, mlir::Location loc,
mlir::omp::TaskgroupOperands &clauseOps) {
static void genTaskgroupClauses(
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
const List<Clause> &clauses, mlir::Location loc,
mlir::omp::TaskgroupOperands &clauseOps,
llvm::SmallVectorImpl<const semantics::Symbol *> &taskReductionSyms) {
ClauseProcessor cp(converter, semaCtx, clauses);
cp.processAllocate(clauseOps);
cp.processTODO<clause::TaskReduction>(loc,
llvm::omp::Directive::OMPD_taskgroup);
cp.processTaskReduction(loc, clauseOps, taskReductionSyms);
}

static void genTaskwaitClauses(lower::AbstractConverter &converter,
Expand Down Expand Up @@ -1887,7 +1887,9 @@ genTaskOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
ConstructQueue::const_iterator item) {
lower::StatementContext stmtCtx;
mlir::omp::TaskOperands clauseOps;
genTaskClauses(converter, semaCtx, stmtCtx, item->clauses, loc, clauseOps);
llvm::SmallVector<const semantics::Symbol *> inReductionSyms;
genTaskClauses(converter, semaCtx, stmtCtx, item->clauses, loc, clauseOps,
inReductionSyms);

if (!enableDelayedPrivatization)
return genOpWithBody<mlir::omp::TaskOp>(
Expand All @@ -1904,22 +1906,27 @@ genTaskOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
EntryBlockArgs taskArgs;
taskArgs.priv.syms = dsp.getDelayedPrivSymbols();
taskArgs.priv.vars = clauseOps.privateVars;
taskArgs.inReduction.syms = inReductionSyms;
taskArgs.inReduction.vars = clauseOps.inReductionVars;

auto genRegionEntryCB = [&](mlir::Operation *op) {
genEntryBlock(converter.getFirOpBuilder(), taskArgs, op->getRegion(0));
bindEntryBlockArgs(converter,
llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op),
taskArgs);
return llvm::to_vector(taskArgs.priv.syms);
return llvm::to_vector(taskArgs.getSyms());
};

return genOpWithBody<mlir::omp::TaskOp>(
OpWithBodyGenInfo genInfo =
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
llvm::omp::Directive::OMPD_task)
.setClauses(&item->clauses)
.setDataSharingProcessor(&dsp)
.setGenRegionEntryCb(genRegionEntryCB),
queue, item, clauseOps);
.setGenRegionEntryCb(genRegionEntryCB);

auto taskOp =
genOpWithBody<mlir::omp::TaskOp>(genInfo, queue, item, clauseOps);
return taskOp;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: It doesn't look like this change is necessary. Since the general convention in this file is to return the result of genOpWithBody directly and construct the OpWithBodyGenInfo structure parameter inside of the call whenever possible, I think this should be left as it was.

}

static mlir::omp::TaskgroupOp
Expand All @@ -1929,13 +1936,28 @@ genTaskgroupOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
const ConstructQueue &queue,
ConstructQueue::const_iterator item) {
mlir::omp::TaskgroupOperands clauseOps;
genTaskgroupClauses(converter, semaCtx, item->clauses, loc, clauseOps);
llvm::SmallVector<const semantics::Symbol *> taskReductionSyms;
genTaskgroupClauses(converter, semaCtx, item->clauses, loc, clauseOps,
taskReductionSyms);

return genOpWithBody<mlir::omp::TaskgroupOp>(
EntryBlockArgs taskgroupArgs;
taskgroupArgs.taskReduction.syms = taskReductionSyms;
taskgroupArgs.taskReduction.vars = clauseOps.taskReductionVars;

auto genRegionEntryCB = [&](mlir::Operation *op) {
genEntryBlock(converter.getFirOpBuilder(), taskgroupArgs, op->getRegion(0));
return llvm::to_vector(taskgroupArgs.getSyms());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is missing the binding of the symbols to the new entry block arguments:

Suggested change
genEntryBlock(converter.getFirOpBuilder(), taskgroupArgs, op->getRegion(0));
return llvm::to_vector(taskgroupArgs.getSyms());
genEntryBlock(converter.getFirOpBuilder(), taskgroupArgs, op->getRegion(0));
bindEntryBlockArgs(converter,
llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op),
taskgroupArgs);
return llvm::to_vector(taskgroupArgs.getSyms());

};

OpWithBodyGenInfo genInfo =
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
llvm::omp::Directive::OMPD_taskgroup)
.setClauses(&item->clauses),
queue, item, clauseOps);
.setClauses(&item->clauses)
.setGenRegionEntryCb(genRegionEntryCB);

auto taskgroupOp =
genOpWithBody<mlir::omp::TaskgroupOp>(genInfo, queue, item, clauseOps);
return taskgroupOp;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Same comment as above. For consistency, return the result of genOpWithBody directly and construct the genInfo argument within the argument list.

}

static mlir::omp::TaskwaitOp
Expand Down
41 changes: 34 additions & 7 deletions flang/lib/Lower/OpenMP/ReductionProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "flang/Parser/tools.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "llvm/Support/CommandLine.h"
#include <type_traits>

static llvm::cl::opt<bool> forceByrefReduction(
"force-byref-reduction",
Expand All @@ -34,6 +35,32 @@ namespace Fortran {
namespace lower {
namespace omp {

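// Explicit template instantiation definitions of addDeclareReduction for the
// supported clause types (Reduction, TaskReduction and InReduction).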
template void ReductionProcessor::addDeclareReduction<omp::clause::Reduction>(
mlir::Location currentLocation, lower::AbstractConverter &converter,
const omp::clause::Reduction &reduction,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols);

template void
ReductionProcessor::addDeclareReduction<omp::clause::TaskReduction>(
mlir::Location currentLocation, lower::AbstractConverter &converter,
const omp::clause::TaskReduction &reduction,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols);

template void ReductionProcessor::addDeclareReduction<omp::clause::InReduction>(
mlir::Location currentLocation, lower::AbstractConverter &converter,
const omp::clause::InReduction &reduction,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols);

ReductionProcessor::ReductionIdentifier ReductionProcessor::getReductionType(
const omp::clause::ProcedureDesignator &pd) {
auto redType = llvm::StringSwitch<std::optional<ReductionIdentifier>>(
Expand Down Expand Up @@ -716,22 +743,22 @@ static bool doReductionByRef(mlir::Value reductionVar) {
return false;
}

template <class T>
void ReductionProcessor::addDeclareReduction(
mlir::Location currentLocation, lower::AbstractConverter &converter,
const omp::clause::Reduction &reduction,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
const T &reduction, llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

if (std::get<std::optional<omp::clause::Reduction::ReductionModifier>>(
reduction.t))
TODO(currentLocation, "Reduction modifiers are not supported");
if constexpr (std::is_same<T, omp::clause::Reduction>::value) {
if (std::get<std::optional<typename T::ReductionModifier>>(reduction.t))
TODO(currentLocation, "Reduction modifiers are not supported");
}

mlir::omp::DeclareReductionOp decl;
const auto &redOperatorList{
std::get<omp::clause::Reduction::ReductionIdentifiers>(reduction.t)};
std::get<typename T::ReductionIdentifiers>(reduction.t)};
assert(redOperatorList.size() == 1 && "Expecting single operator");
const auto &redOperator = redOperatorList.front();
const auto &objectList{std::get<omp::ObjectList>(reduction.t)};
Expand Down
4 changes: 2 additions & 2 deletions flang/lib/Lower/OpenMP/ReductionProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,10 @@ class ReductionProcessor {

/// Creates a reduction declaration and associates it with an OpenMP block
/// directive.
template <class T>
static void addDeclareReduction(
mlir::Location currentLocation, lower::AbstractConverter &converter,
const omp::clause::Reduction &reduction,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
const T &reduction, llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<bool> &reduceVarByRef,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSymbols);
Expand Down
15 changes: 0 additions & 15 deletions flang/test/Lower/OpenMP/Todo/task-inreduction.f90

This file was deleted.

10 changes: 0 additions & 10 deletions flang/test/Lower/OpenMP/Todo/taskgroup-task-reduction.f90

This file was deleted.

35 changes: 35 additions & 0 deletions flang/test/Lower/OpenMP/task-inreduction.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
! RUN: bbc -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=50 -o - %s 2>&1 | FileCheck %s

!CHECK-LABEL: omp.declare_reduction
!CHECK-SAME: @[[RED_I32_NAME:.*]] : i32 init {
!CHECK: ^bb0(%{{.*}}: i32):
!CHECK: %[[C0_1:.*]] = arith.constant 0 : i32
!CHECK: omp.yield(%[[C0_1]] : i32)
!CHECK: } combiner {
!CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32):
!CHECK: %[[RES:.*]] = arith.addi %[[ARG0]], %[[ARG1]] : i32
!CHECK: omp.yield(%[[RES]] : i32)
!CHECK: }

!CHECK-LABEL: func.func @_QPomp_task_in_reduction() {
! [...]
!CHECK: omp.task in_reduction(@[[RED_I32_NAME]] %[[VAL_1:.*]]#0 -> %[[ARG0]] : !fir.ref<i32>) {
!CHECK: %[[VAL_4:.*]]:2 = hlfir.declare %[[ARG0]]
!CHECK-SAME: {uniq_name = "_QFomp_task_in_reductionEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[VAL_5:.*]] = fir.load %[[VAL_4]]#0 : !fir.ref<i32>
!CHECK: %[[VAL_6:.*]] = arith.constant 1 : i32
!CHECK: %[[VAL_7:.*]] = arith.addi %[[VAL_5]], %[[VAL_6]] : i32
!CHECK: hlfir.assign %[[VAL_7]] to %[[VAL_4]]#0 : i32, !fir.ref<i32>
!CHECK: omp.terminator
!CHECK: }
!CHECK: return
!CHECK: }

! Test input: an integer variable is initialised to zero and then incremented
! inside a task construct that lists it in an in_reduction(+:i) clause.
subroutine omp_task_in_reduction()
integer i
i = 0
!$omp task in_reduction(+:i)
i = i + 1
!$omp end task
end subroutine omp_task_in_reduction
Loading
Loading