Skip to content

Commit 1b5a47e

Browse files
committed
[Flang][OpenMP] Addressed review comments
1 parent 60cbcc2 commit 1b5a47e

File tree

7 files changed

+33
-39
lines changed

7 files changed

+33
-39
lines changed

flang/lib/Lower/OpenMP/ClauseProcessor.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ class ClauseProcessor {
107107
llvm::SmallVectorImpl<const semantics::Symbol *> &isDeviceSyms) const;
108108
bool processInReduction(
109109
mlir::Location currentLocation, mlir::omp::InReductionClauseOps &result,
110-
llvm::SmallVectorImpl<const semantics::Symbol *> &InReductionSyms) const;
110+
llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const;
111111
bool
112112
processLink(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
113113

@@ -126,10 +126,9 @@ class ClauseProcessor {
126126
bool processReduction(
127127
mlir::Location currentLocation, mlir::omp::ReductionClauseOps &result,
128128
llvm::SmallVectorImpl<const semantics::Symbol *> &reductionSyms) const;
129-
bool processTaskReduction(mlir::Location currentLocation,
130-
mlir::omp::TaskReductionClauseOps &result,
131-
llvm::SmallVectorImpl<const semantics::Symbol *>
132-
&TaskReductionSyms) const;
129+
bool processTaskReduction(
130+
mlir::Location currentLocation, mlir::omp::TaskReductionClauseOps &result,
131+
llvm::SmallVectorImpl<const semantics::Symbol *> &outReductionSyms) const;
133132
bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
134133
bool processUseDeviceAddr(
135134
lower::StatementContext &stmtCtx,

flang/lib/Lower/OpenMP/DataSharingProcessor.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -344,11 +344,13 @@ void DataSharingProcessor::collectSymbols(
344344
// Collect all symbols referenced in the evaluation being processed,
345345
// that matches 'flag'.
346346
llvm::SetVector<const semantics::Symbol *> allSymbols;
347-
bool collectSymbols = true;
348-
for (const omp::Clause &clause : clauses) {
349-
if (clause.id == llvm::omp::Clause::OMPC_in_reduction)
350-
collectSymbols = false;
351-
}
347+
348+
auto itr = llvm::find_if(clauses, [](const omp::Clause &clause) {
349+
return clause.id == llvm::omp::Clause::OMPC_in_reduction;
350+
});
351+
352+
bool collectSymbols = (itr == clauses.end());
353+
352354
converter.collectSymbolSet(eval, allSymbols, flag,
353355
/*collectSymbols=*/collectSymbols,
354356
/*collectHostAssociatedSymbols=*/true);

flang/lib/Lower/OpenMP/OpenMP.cpp

Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1253,21 +1253,20 @@ static void genTaskClauses(
12531253
lower::AbstractConverter &converter, semantics::SemanticsContext &semaCtx,
12541254
lower::StatementContext &stmtCtx, const List<Clause> &clauses,
12551255
mlir::Location loc, mlir::omp::TaskOperands &clauseOps,
1256-
llvm::SmallVectorImpl<const semantics::Symbol *> &InReductionSyms) {
1256+
llvm::SmallVectorImpl<const semantics::Symbol *> &inReductionSyms) {
12571257
ClauseProcessor cp(converter, semaCtx, clauses);
12581258
cp.processAllocate(clauseOps);
12591259
cp.processDepend(clauseOps);
12601260
cp.processFinal(stmtCtx, clauseOps);
12611261
cp.processIf(llvm::omp::Directive::OMPD_task, clauseOps);
1262-
cp.processInReduction(loc, clauseOps, InReductionSyms);
1262+
cp.processInReduction(loc, clauseOps, inReductionSyms);
12631263
cp.processMergeable(clauseOps);
12641264
cp.processPriority(stmtCtx, clauseOps);
12651265
cp.processUntied(clauseOps);
12661266
cp.processDetach(clauseOps);
12671267
// TODO Support delayed privatization.
12681268

1269-
cp.processTODO<clause::Affinity>(
1270-
loc, llvm::omp::Directive::OMPD_task);
1269+
cp.processTODO<clause::Affinity>(loc, llvm::omp::Directive::OMPD_task);
12711270
}
12721271

12731272
static void genTaskgroupClauses(
@@ -1888,9 +1887,9 @@ genTaskOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
18881887
ConstructQueue::const_iterator item) {
18891888
lower::StatementContext stmtCtx;
18901889
mlir::omp::TaskOperands clauseOps;
1891-
llvm::SmallVector<const semantics::Symbol *> InReductionSyms;
1890+
llvm::SmallVector<const semantics::Symbol *> inReductionSyms;
18921891
genTaskClauses(converter, semaCtx, stmtCtx, item->clauses, loc, clauseOps,
1893-
InReductionSyms);
1892+
inReductionSyms);
18941893

18951894
if (!enableDelayedPrivatization)
18961895
return genOpWithBody<mlir::omp::TaskOp>(
@@ -1907,7 +1906,7 @@ genTaskOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
19071906
EntryBlockArgs taskArgs;
19081907
taskArgs.priv.syms = dsp.getDelayedPrivSymbols();
19091908
taskArgs.priv.vars = clauseOps.privateVars;
1910-
taskArgs.inReduction.syms = InReductionSyms;
1909+
taskArgs.inReduction.syms = inReductionSyms;
19111910
taskArgs.inReduction.vars = clauseOps.inReductionVars;
19121911

19131912
auto genRegionEntryCB = [&](mlir::Operation *op) {
@@ -1927,14 +1926,6 @@ genTaskOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
19271926

19281927
auto taskOp =
19291928
genOpWithBody<mlir::omp::TaskOp>(genInfo, queue, item, clauseOps);
1930-
1931-
llvm::SmallVector<mlir::Type> inReductionTypes;
1932-
for (const auto &inreductionVar : clauseOps.inReductionVars)
1933-
inReductionTypes.push_back(inreductionVar.getType());
1934-
1935-
// Add reduction variables as entry block arguments to the task region
1936-
llvm::SmallVector<mlir::Location> blockArgLocs(InReductionSyms.size(), loc);
1937-
taskOp->getRegion(0).addArguments(inReductionTypes, blockArgLocs);
19381929
return taskOp;
19391930
}
19401931

@@ -1949,21 +1940,23 @@ genTaskgroupOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
19491940
genTaskgroupClauses(converter, semaCtx, item->clauses, loc, clauseOps,
19501941
taskReductionSyms);
19511942

1943+
EntryBlockArgs taskgroupArgs;
1944+
taskgroupArgs.taskReduction.syms = taskReductionSyms;
1945+
taskgroupArgs.taskReduction.vars = clauseOps.taskReductionVars;
1946+
1947+
auto genRegionEntryCB = [&](mlir::Operation *op) {
1948+
genEntryBlock(converter.getFirOpBuilder(), taskgroupArgs, op->getRegion(0));
1949+
return llvm::to_vector(taskgroupArgs.getSyms());
1950+
};
1951+
19521952
OpWithBodyGenInfo genInfo =
19531953
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
19541954
llvm::omp::Directive::OMPD_taskgroup)
1955-
.setClauses(&item->clauses);
1955+
.setClauses(&item->clauses)
1956+
.setGenRegionEntryCb(genRegionEntryCB);
19561957

19571958
auto taskgroupOp =
19581959
genOpWithBody<mlir::omp::TaskgroupOp>(genInfo, queue, item, clauseOps);
1959-
1960-
llvm::SmallVector<mlir::Type> taskReductionTypes;
1961-
for (const auto &taskreductionVar : clauseOps.taskReductionVars)
1962-
taskReductionTypes.push_back(taskreductionVar.getType());
1963-
1964-
// Add reduction variables as entry block arguments to the taskgroup region
1965-
llvm::SmallVector<mlir::Location> blockArgLocs(taskReductionSyms.size(), loc);
1966-
taskgroupOp->getRegion(0).addArguments(taskReductionTypes, blockArgLocs);
19671960
return taskgroupOp;
19681961
}
19691962

flang/test/Lower/OpenMP/task-inreduction.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,4 @@ subroutine omp_task_in_reduction()
3232
!$omp task in_reduction(+:i)
3333
i = i + 1
3434
!$omp end task
35-
end subroutine omp_task_in_reduction
35+
end subroutine omp_task_in_reduction

flang/test/Lower/OpenMP/taskgroup-task-array-reduction.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,4 +46,4 @@ subroutine taskReduction(x)
4646
!$omp end task
4747
!$omp end taskgroup
4848
!$omp end parallel
49-
end subroutine
49+
end subroutine

flang/test/Lower/OpenMP/taskgroup-task_reduction01.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,4 @@ subroutine omp_taskgroup_task_reduction()
3131
!$omp taskgroup task_reduction(+:res)
3232
res = res + 1
3333
!$omp end taskgroup
34-
end subroutine omp_taskgroup_task_reduction
34+
end subroutine

flang/test/Lower/OpenMP/taskgroup-task_reduction02.f90

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,12 @@
2525
!CHECK: return
2626
!CHECK: }
2727

28-
subroutine in_reduction
28+
subroutine in_reduction()
2929
integer :: x
3030
x = 0
3131
!$omp taskgroup task_reduction(+:x)
3232
!$omp task in_reduction(+:x)
3333
x = x + 1
3434
!$omp end task
3535
!$omp end taskgroup
36-
end subroutine
36+
end subroutine

0 commit comments

Comments
 (0)