Skip to content

Commit 6987182

Browse files
committed
[flang][OpenMP] do concurrent: support reduce on device
Extends `do concurrent` to OpenMP device mapping by adding support for mapping `reduce` specifiers to omp `reduction` clauses. The changes attach 2 `reduction` clauses to the mapped OpenMP construct: one on the `teams` part of the construct and one on the `wloop` part.
1 parent 78e1013 commit 6987182

File tree

2 files changed

+121
-49
lines changed

2 files changed

+121
-49
lines changed

flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp

Lines changed: 68 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,9 @@ void collectLoopLiveIns(fir::DoConcurrentLoopOp loop,
140140

141141
for (mlir::Value local : loop.getLocalVars())
142142
liveIns.push_back(local);
143+
144+
for (mlir::Value reduce : loop.getReduceVars())
145+
liveIns.push_back(reduce);
143146
}
144147

145148
/// Collects values that are local to a loop: "loop-local values". A loop-local
@@ -272,7 +275,7 @@ class DoConcurrentConversion
272275
targetOp =
273276
genTargetOp(doLoop.getLoc(), rewriter, mapper, loopNestLiveIns,
274277
targetClauseOps, loopNestClauseOps, liveInShapeInfoMap);
275-
genTeamsOp(doLoop.getLoc(), rewriter);
278+
genTeamsOp(rewriter, loop, mapper);
276279
}
277280

278281
mlir::omp::ParallelOp parallelOp =
@@ -488,46 +491,7 @@ class DoConcurrentConversion
488491
if (!mapToDevice)
489492
genPrivatizers(rewriter, mapper, loop, wsloopClauseOps);
490493

491-
if (!loop.getReduceVars().empty()) {
492-
for (auto [op, byRef, sym, arg] : llvm::zip_equal(
493-
loop.getReduceVars(), loop.getReduceByrefAttr().asArrayRef(),
494-
loop.getReduceSymsAttr().getAsRange<mlir::SymbolRefAttr>(),
495-
loop.getRegionReduceArgs())) {
496-
auto firReducer = moduleSymbolTable.lookup<fir::DeclareReductionOp>(
497-
sym.getLeafReference());
498-
499-
mlir::OpBuilder::InsertionGuard guard(rewriter);
500-
rewriter.setInsertionPointAfter(firReducer);
501-
std::string ompReducerName = sym.getLeafReference().str() + ".omp";
502-
503-
auto ompReducer =
504-
moduleSymbolTable.lookup<mlir::omp::DeclareReductionOp>(
505-
rewriter.getStringAttr(ompReducerName));
506-
507-
if (!ompReducer) {
508-
ompReducer = mlir::omp::DeclareReductionOp::create(
509-
rewriter, firReducer.getLoc(), ompReducerName,
510-
firReducer.getTypeAttr().getValue());
511-
512-
cloneFIRRegionToOMP(rewriter, firReducer.getAllocRegion(),
513-
ompReducer.getAllocRegion());
514-
cloneFIRRegionToOMP(rewriter, firReducer.getInitializerRegion(),
515-
ompReducer.getInitializerRegion());
516-
cloneFIRRegionToOMP(rewriter, firReducer.getReductionRegion(),
517-
ompReducer.getReductionRegion());
518-
cloneFIRRegionToOMP(rewriter, firReducer.getAtomicReductionRegion(),
519-
ompReducer.getAtomicReductionRegion());
520-
cloneFIRRegionToOMP(rewriter, firReducer.getCleanupRegion(),
521-
ompReducer.getCleanupRegion());
522-
moduleSymbolTable.insert(ompReducer);
523-
}
524-
525-
wsloopClauseOps.reductionVars.push_back(op);
526-
wsloopClauseOps.reductionByref.push_back(byRef);
527-
wsloopClauseOps.reductionSyms.push_back(
528-
mlir::SymbolRefAttr::get(ompReducer));
529-
}
530-
}
494+
genReductions(rewriter, mapper, loop, wsloopClauseOps);
531495

532496
auto wsloopOp =
533497
mlir::omp::WsloopOp::create(rewriter, loop.getLoc(), wsloopClauseOps);
@@ -549,8 +513,6 @@ class DoConcurrentConversion
549513

550514
rewriter.setInsertionPointToEnd(&loopNestOp.getRegion().back());
551515
mlir::omp::YieldOp::create(rewriter, loop->getLoc());
552-
loop->getParentOfType<mlir::ModuleOp>().print(
553-
llvm::errs(), mlir::OpPrintingFlags().assumeVerified());
554516

555517
return {loopNestOp, wsloopOp};
556518
}
@@ -771,15 +733,26 @@ class DoConcurrentConversion
771733
liveInName, shape);
772734
}
773735

774-
mlir::omp::TeamsOp
775-
genTeamsOp(mlir::Location loc,
776-
mlir::ConversionPatternRewriter &rewriter) const {
777-
auto teamsOp = rewriter.create<mlir::omp::TeamsOp>(
778-
loc, /*clauses=*/mlir::omp::TeamsOperands{});
736+
mlir::omp::TeamsOp genTeamsOp(mlir::ConversionPatternRewriter &rewriter,
737+
fir::DoConcurrentLoopOp loop,
738+
mlir::IRMapping &mapper) const {
739+
mlir::omp::TeamsOperands teamsOps;
740+
genReductions(rewriter, mapper, loop, teamsOps);
741+
742+
mlir::Location loc = loop.getLoc();
743+
auto teamsOp = rewriter.create<mlir::omp::TeamsOp>(loc, teamsOps);
744+
Fortran::common::openmp::EntryBlockArgs teamsArgs;
745+
teamsArgs.reduction.vars = teamsOps.reductionVars;
746+
Fortran::common::openmp::genEntryBlock(rewriter, teamsArgs,
747+
teamsOp.getRegion());
779748

780-
rewriter.createBlock(&teamsOp.getRegion());
781749
rewriter.setInsertionPoint(rewriter.create<mlir::omp::TerminatorOp>(loc));
782750

751+
for (auto [loopVar, teamsArg] : llvm::zip_equal(
752+
loop.getReduceVars(), teamsOp.getRegion().getArguments())) {
753+
mapper.map(loopVar, teamsArg);
754+
}
755+
783756
return teamsOp;
784757
}
785758

@@ -846,6 +819,52 @@ class DoConcurrentConversion
846819
}
847820
}
848821

822+
void genReductions(mlir::ConversionPatternRewriter &rewriter,
823+
mlir::IRMapping &mapper, fir::DoConcurrentLoopOp loop,
824+
mlir::omp::ReductionClauseOps &reductionClauseOps) const {
825+
if (!loop.getReduceVars().empty()) {
826+
for (auto [var, byRef, sym, arg] : llvm::zip_equal(
827+
loop.getReduceVars(), loop.getReduceByrefAttr().asArrayRef(),
828+
loop.getReduceSymsAttr().getAsRange<mlir::SymbolRefAttr>(),
829+
loop.getRegionReduceArgs())) {
830+
auto firReducer = moduleSymbolTable.lookup<fir::DeclareReductionOp>(
831+
sym.getLeafReference());
832+
833+
mlir::OpBuilder::InsertionGuard guard(rewriter);
834+
rewriter.setInsertionPointAfter(firReducer);
835+
std::string ompReducerName = sym.getLeafReference().str() + ".omp";
836+
837+
auto ompReducer =
838+
moduleSymbolTable.lookup<mlir::omp::DeclareReductionOp>(
839+
rewriter.getStringAttr(ompReducerName));
840+
841+
if (!ompReducer) {
842+
ompReducer = mlir::omp::DeclareReductionOp::create(
843+
rewriter, firReducer.getLoc(), ompReducerName,
844+
firReducer.getTypeAttr().getValue());
845+
846+
cloneFIRRegionToOMP(rewriter, firReducer.getAllocRegion(),
847+
ompReducer.getAllocRegion());
848+
cloneFIRRegionToOMP(rewriter, firReducer.getInitializerRegion(),
849+
ompReducer.getInitializerRegion());
850+
cloneFIRRegionToOMP(rewriter, firReducer.getReductionRegion(),
851+
ompReducer.getReductionRegion());
852+
cloneFIRRegionToOMP(rewriter, firReducer.getAtomicReductionRegion(),
853+
ompReducer.getAtomicReductionRegion());
854+
cloneFIRRegionToOMP(rewriter, firReducer.getCleanupRegion(),
855+
ompReducer.getCleanupRegion());
856+
moduleSymbolTable.insert(ompReducer);
857+
}
858+
859+
reductionClauseOps.reductionVars.push_back(
860+
mapToDevice ? mapper.lookup(var) : var);
861+
reductionClauseOps.reductionByref.push_back(byRef);
862+
reductionClauseOps.reductionSyms.push_back(
863+
mlir::SymbolRefAttr::get(ompReducer));
864+
}
865+
}
866+
}
867+
849868
bool mapToDevice;
850869
llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip;
851870
mlir::SymbolTable &moduleSymbolTable;
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// RUN: fir-opt --omp-do-concurrent-conversion="map-to=device" %s -o - | FileCheck %s
2+
3+
fir.declare_reduction @add_reduction_f32 : f32 init {
4+
^bb0(%arg0: f32):
5+
%cst = arith.constant 0.000000e+00 : f32
6+
fir.yield(%cst : f32)
7+
} combiner {
8+
^bb0(%arg0: f32, %arg1: f32):
9+
%0 = arith.addf %arg0, %arg1 fastmath<contract> : f32
10+
fir.yield(%0 : f32)
11+
}
12+
13+
func.func @_QPfoo() {
14+
%0 = fir.dummy_scope : !fir.dscope
15+
%3 = fir.alloca f32 {bindc_name = "s", uniq_name = "_QFfooEs"}
16+
%4:2 = hlfir.declare %3 {uniq_name = "_QFfooEs"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
17+
%c1 = arith.constant 1 : index
18+
%c10 = arith.constant 1 : index
19+
fir.do_concurrent {
20+
%7 = fir.alloca i32 {bindc_name = "i"}
21+
%8:2 = hlfir.declare %7 {uniq_name = "_QFfooEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
22+
fir.do_concurrent.loop (%arg0) = (%c1) to (%c10) step (%c1) reduce(@add_reduction_f32 #fir.reduce_attr<add> %4#0 -> %arg1 : !fir.ref<f32>) {
23+
%9 = fir.convert %arg0 : (index) -> i32
24+
fir.store %9 to %8#0 : !fir.ref<i32>
25+
%10:2 = hlfir.declare %arg1 {uniq_name = "_QFfooEs"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
26+
%11 = fir.load %10#0 : !fir.ref<f32>
27+
%cst = arith.constant 1.000000e+00 : f32
28+
%12 = arith.addf %11, %cst fastmath<contract> : f32
29+
hlfir.assign %12 to %10#0 : f32, !fir.ref<f32>
30+
}
31+
}
32+
return
33+
}
34+
35+
// CHECK: omp.declare_reduction @[[OMP_RED:.*.omp]] : f32
36+
37+
// CHECK: %[[S_DECL:.*]]:2 = hlfir.declare %6 {uniq_name = "_QFfooEs"}
38+
// CHECK: %[[S_MAP:.*]] = omp.map.info var_ptr(%[[S_DECL]]#1
39+
40+
// CHECK: omp.target host_eval({{.*}}) map_entries({{.*}}, %[[S_MAP]] -> %[[S_TARGET_ARG:.*]] : {{.*}}) {
41+
// CHECK: %[[S_DEV_DECL:.*]]:2 = hlfir.declare %[[S_TARGET_ARG]]
42+
// CHECK: omp.teams reduction(@[[OMP_RED]] %[[S_DEV_DECL]]#0 -> %[[RED_TEAMS_ARG:.*]] : !fir.ref<f32>) {
43+
// CHECK: omp.parallel {
44+
// CHECK: omp.distribute {
45+
// CHECK: omp.wsloop reduction(@[[OMP_RED]] %[[RED_TEAMS_ARG]] -> %[[RED_WS_ARG:.*]] : {{.*}}) {
46+
// CHECK: %[[S_WS_DECL:.*]]:2 = hlfir.declare %[[RED_WS_ARG]] {uniq_name = "_QFfooEs"}
47+
// CHECK: %[[S_VAL:.*]] = fir.load %[[S_WS_DECL]]#0
48+
// CHECK: %[[RED_RES:.*]] = arith.addf %[[S_VAL]], %{{.*}} fastmath<contract> : f32
49+
// CHECK: hlfir.assign %[[RED_RES]] to %[[S_WS_DECL]]#0
50+
// CHECK: }
51+
// CHECK: }
52+
// CHECK: }
53+
// CHECK: }

0 commit comments

Comments
 (0)