Skip to content

Commit 9b75446

Browse files
authored
[flang][OpenMP] do concurrent: support reduce on device (#156610)
Extends `do concurrent` to OpenMP device mapping by adding support for mapping `reduce` specifiers to omp `reduction` clauses. The changes attach 2 `reduction` clauses to the mapped OpenMP construct: one on the `teams` part of the construct and one on the `wsloop` part. - #155754 - #155987 - #155992 - #155993 - #157638 - #156610 ◀️ - #156837
1 parent 15f05dc commit 9b75446

File tree

2 files changed

+121
-49
lines changed

2 files changed

+121
-49
lines changed

flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp

Lines changed: 68 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,9 @@ void collectLoopLiveIns(fir::DoConcurrentLoopOp loop,
141141

142142
for (mlir::Value local : loop.getLocalVars())
143143
liveIns.push_back(local);
144+
145+
for (mlir::Value reduce : loop.getReduceVars())
146+
liveIns.push_back(reduce);
144147
}
145148

146149
/// Collects values that are local to a loop: "loop-local values". A loop-local
@@ -319,7 +322,7 @@ class DoConcurrentConversion
319322
targetOp =
320323
genTargetOp(doLoop.getLoc(), rewriter, mapper, loopNestLiveIns,
321324
targetClauseOps, loopNestClauseOps, liveInShapeInfoMap);
322-
genTeamsOp(doLoop.getLoc(), rewriter);
325+
genTeamsOp(rewriter, loop, mapper);
323326
}
324327

325328
mlir::omp::ParallelOp parallelOp =
@@ -492,46 +495,7 @@ class DoConcurrentConversion
492495
if (!mapToDevice)
493496
genPrivatizers(rewriter, mapper, loop, wsloopClauseOps);
494497

495-
if (!loop.getReduceVars().empty()) {
496-
for (auto [op, byRef, sym, arg] : llvm::zip_equal(
497-
loop.getReduceVars(), loop.getReduceByrefAttr().asArrayRef(),
498-
loop.getReduceSymsAttr().getAsRange<mlir::SymbolRefAttr>(),
499-
loop.getRegionReduceArgs())) {
500-
auto firReducer = moduleSymbolTable.lookup<fir::DeclareReductionOp>(
501-
sym.getLeafReference());
502-
503-
mlir::OpBuilder::InsertionGuard guard(rewriter);
504-
rewriter.setInsertionPointAfter(firReducer);
505-
std::string ompReducerName = sym.getLeafReference().str() + ".omp";
506-
507-
auto ompReducer =
508-
moduleSymbolTable.lookup<mlir::omp::DeclareReductionOp>(
509-
rewriter.getStringAttr(ompReducerName));
510-
511-
if (!ompReducer) {
512-
ompReducer = mlir::omp::DeclareReductionOp::create(
513-
rewriter, firReducer.getLoc(), ompReducerName,
514-
firReducer.getTypeAttr().getValue());
515-
516-
cloneFIRRegionToOMP(rewriter, firReducer.getAllocRegion(),
517-
ompReducer.getAllocRegion());
518-
cloneFIRRegionToOMP(rewriter, firReducer.getInitializerRegion(),
519-
ompReducer.getInitializerRegion());
520-
cloneFIRRegionToOMP(rewriter, firReducer.getReductionRegion(),
521-
ompReducer.getReductionRegion());
522-
cloneFIRRegionToOMP(rewriter, firReducer.getAtomicReductionRegion(),
523-
ompReducer.getAtomicReductionRegion());
524-
cloneFIRRegionToOMP(rewriter, firReducer.getCleanupRegion(),
525-
ompReducer.getCleanupRegion());
526-
moduleSymbolTable.insert(ompReducer);
527-
}
528-
529-
wsloopClauseOps.reductionVars.push_back(op);
530-
wsloopClauseOps.reductionByref.push_back(byRef);
531-
wsloopClauseOps.reductionSyms.push_back(
532-
mlir::SymbolRefAttr::get(ompReducer));
533-
}
534-
}
498+
genReductions(rewriter, mapper, loop, wsloopClauseOps);
535499

536500
auto wsloopOp =
537501
mlir::omp::WsloopOp::create(rewriter, loop.getLoc(), wsloopClauseOps);
@@ -553,8 +517,6 @@ class DoConcurrentConversion
553517

554518
rewriter.setInsertionPointToEnd(&loopNestOp.getRegion().back());
555519
mlir::omp::YieldOp::create(rewriter, loop->getLoc());
556-
loop->getParentOfType<mlir::ModuleOp>().print(
557-
llvm::errs(), mlir::OpPrintingFlags().assumeVerified());
558520

559521
return {loopNestOp, wsloopOp};
560522
}
@@ -778,15 +740,26 @@ class DoConcurrentConversion
778740
liveInName, shape);
779741
}
780742

781-
mlir::omp::TeamsOp
782-
genTeamsOp(mlir::Location loc,
783-
mlir::ConversionPatternRewriter &rewriter) const {
784-
auto teamsOp = rewriter.create<mlir::omp::TeamsOp>(
785-
loc, /*clauses=*/mlir::omp::TeamsOperands{});
743+
/// Creates the `omp.teams` op for a `do concurrent` loop.
///
/// The teams clause operands are filled in by `genReductions` from the loop's
/// `reduce` specifiers. After the op is created, its region's entry block is
/// generated from the collected reduction variables, a terminator is created,
/// and the rewriter's insertion point is set at that terminator.
///
/// \param rewriter rewriter used to create the ops.
/// \param loop     the `fir.do_concurrent.loop` whose reduce variables are
///                 turned into the teams reduction clause.
/// \param mapper   value mapping; on return each of the loop's reduce
///                 variables is mapped to a teams region argument.
/// \return the newly created `omp.teams` op.
mlir::omp::TeamsOp genTeamsOp(mlir::ConversionPatternRewriter &rewriter,
744+
fir::DoConcurrentLoopOp loop,
745+
mlir::IRMapping &mapper) const {
746+
mlir::omp::TeamsOperands teamsOps;
747+
genReductions(rewriter, mapper, loop, teamsOps);
748+
749+
mlir::Location loc = loop.getLoc();
750+
auto teamsOp = rewriter.create<mlir::omp::TeamsOp>(loc, teamsOps);
751+
// Presumably genEntryBlock adds one entry-block argument per reduction
// variable; the zip_equal below relies on the counts matching.
Fortran::common::openmp::EntryBlockArgs teamsArgs;
752+
teamsArgs.reduction.vars = teamsOps.reductionVars;
753+
Fortran::common::openmp::genEntryBlock(rewriter, teamsArgs,
754+
teamsOp.getRegion());
786755

787-
rewriter.createBlock(&teamsOp.getRegion());
788756
rewriter.setInsertionPoint(rewriter.create<mlir::omp::TerminatorOp>(loc));
789757

758+
// Remap each reduce variable to its teams region argument. genReductions
// resolves reduce variables through `mapper` when mapping to device, so a
// later call picks up the teams-level value.
for (auto [loopVar, teamsArg] : llvm::zip_equal(
759+
loop.getReduceVars(), teamsOp.getRegion().getArguments())) {
760+
mapper.map(loopVar, teamsArg);
761+
}
762+
790763
return teamsOp;
791764
}
792765

@@ -861,6 +834,52 @@ class DoConcurrentConversion
861834
}
862835
}
863836

837+
/// Appends one reduction entry to `reductionClauseOps` for every `reduce`
/// specifier on `loop`.
///
/// For each specifier, the `fir.declare_reduction` op named by the specifier's
/// symbol is looked up in the module symbol table. An
/// `omp.declare_reduction` named "<fir reducer name>.omp" is then looked up
/// and, if missing, created right after the FIR reducer. Its alloc,
/// initializer, reduction, atomic-reduction and cleanup regions are populated
/// through `cloneFIRRegionToOMP`, and the new op is inserted into the symbol
/// table.
///
/// \param rewriter           rewriter used to create the OpenMP reducer; its
///                           insertion point is restored after each iteration.
/// \param mapper             used to resolve the reduce variable when
///                           `mapToDevice` is set.
/// \param loop               the `fir.do_concurrent.loop` providing the reduce
///                           variables, by-ref flags and reducer symbols.
/// \param reductionClauseOps receives the reduction vars, by-ref flags and
///                           symbol references.
///
/// NOTE(review): `firReducer` is used without a null check, and the `arg`
/// binding from `getRegionReduceArgs()` is not used inside the loop body.
void genReductions(mlir::ConversionPatternRewriter &rewriter,
838+
mlir::IRMapping &mapper, fir::DoConcurrentLoopOp loop,
839+
mlir::omp::ReductionClauseOps &reductionClauseOps) const {
840+
if (!loop.getReduceVars().empty()) {
841+
for (auto [var, byRef, sym, arg] : llvm::zip_equal(
842+
loop.getReduceVars(), loop.getReduceByrefAttr().asArrayRef(),
843+
loop.getReduceSymsAttr().getAsRange<mlir::SymbolRefAttr>(),
844+
loop.getRegionReduceArgs())) {
845+
auto firReducer = moduleSymbolTable.lookup<fir::DeclareReductionOp>(
846+
sym.getLeafReference());
847+
848+
// Place the OpenMP reducer next to its FIR counterpart; the guard
// restores the caller's insertion point at the end of this iteration.
mlir::OpBuilder::InsertionGuard guard(rewriter);
849+
rewriter.setInsertionPointAfter(firReducer);
850+
std::string ompReducerName = sym.getLeafReference().str() + ".omp";
851+
852+
// Reuse an OpenMP reducer that already exists under this name.
auto ompReducer =
853+
moduleSymbolTable.lookup<mlir::omp::DeclareReductionOp>(
854+
rewriter.getStringAttr(ompReducerName));
855+
856+
if (!ompReducer) {
857+
ompReducer = mlir::omp::DeclareReductionOp::create(
858+
rewriter, firReducer.getLoc(), ompReducerName,
859+
firReducer.getTypeAttr().getValue());
860+
861+
cloneFIRRegionToOMP(rewriter, firReducer.getAllocRegion(),
862+
ompReducer.getAllocRegion());
863+
cloneFIRRegionToOMP(rewriter, firReducer.getInitializerRegion(),
864+
ompReducer.getInitializerRegion());
865+
cloneFIRRegionToOMP(rewriter, firReducer.getReductionRegion(),
866+
ompReducer.getReductionRegion());
867+
cloneFIRRegionToOMP(rewriter, firReducer.getAtomicReductionRegion(),
868+
ompReducer.getAtomicReductionRegion());
869+
cloneFIRRegionToOMP(rewriter, firReducer.getCleanupRegion(),
870+
ompReducer.getCleanupRegion());
871+
moduleSymbolTable.insert(ompReducer);
872+
}
873+
874+
// When mapping to device, use the value `mapper` holds for the reduce
// variable instead of the original one.
reductionClauseOps.reductionVars.push_back(
875+
mapToDevice ? mapper.lookup(var) : var);
876+
reductionClauseOps.reductionByref.push_back(byRef);
877+
reductionClauseOps.reductionSyms.push_back(
878+
mlir::SymbolRefAttr::get(ompReducer));
879+
}
880+
}
881+
}
882+
864883
bool mapToDevice;
865884
llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip;
866885
mlir::SymbolTable &moduleSymbolTable;
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// RUN: fir-opt --omp-do-concurrent-conversion="map-to=device" %s -o - | FileCheck %s
2+
3+
// f32 reduction declaration referenced by the `reduce` specifier of the loop
// in @_QPfoo: the init region yields the constant 0.0 and the combiner yields
// the arith.addf of its two block arguments.
fir.declare_reduction @add_reduction_f32 : f32 init {
4+
^bb0(%arg0: f32):
5+
%cst = arith.constant 0.000000e+00 : f32
6+
fir.yield(%cst : f32)
7+
} combiner {
8+
^bb0(%arg0: f32, %arg1: f32):
9+
%0 = arith.addf %arg0, %arg1 fastmath<contract> : f32
10+
fir.yield(%0 : f32)
11+
}
12+
13+
// Test input: a function with a single `fir.do_concurrent` loop that reduces
// into the f32 variable "s" (%4#0) using @add_reduction_f32. The loop body
// stores the converted induction value into "i", then loads the reduction
// block argument, adds 1.0 to it and assigns the result back.
func.func @_QPfoo() {
14+
%0 = fir.dummy_scope : !fir.dscope
15+
%3 = fir.alloca f32 {bindc_name = "s", uniq_name = "_QFfooEs"}
16+
%4:2 = hlfir.declare %3 {uniq_name = "_QFfooEs"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
17+
%c1 = arith.constant 1 : index
18+
// NOTE(review): despite its name, %c10 holds the constant 1, so the loop
// bounds below are 1 to 1 - confirm whether 10 was intended.
%c10 = arith.constant 1 : index
19+
fir.do_concurrent {
20+
%7 = fir.alloca i32 {bindc_name = "i"}
21+
%8:2 = hlfir.declare %7 {uniq_name = "_QFfooEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
22+
fir.do_concurrent.loop (%arg0) = (%c1) to (%c10) step (%c1) reduce(@add_reduction_f32 #fir.reduce_attr<add> %4#0 -> %arg1 : !fir.ref<f32>) {
23+
%9 = fir.convert %arg0 : (index) -> i32
24+
fir.store %9 to %8#0 : !fir.ref<i32>
25+
// %arg1 is the loop's reduction block argument bound to %4#0.
%10:2 = hlfir.declare %arg1 {uniq_name = "_QFfooEs"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
26+
%11 = fir.load %10#0 : !fir.ref<f32>
27+
%cst = arith.constant 1.000000e+00 : f32
28+
%12 = arith.addf %11, %cst fastmath<contract> : f32
29+
hlfir.assign %12 to %10#0 : f32, !fir.ref<f32>
30+
}
31+
}
32+
return
33+
}
34+
35+
// CHECK: omp.declare_reduction @[[OMP_RED:.*.omp]] : f32
36+
37+
// CHECK: %[[S_DECL:.*]]:2 = hlfir.declare %6 {uniq_name = "_QFfooEs"}
38+
// CHECK: %[[S_MAP:.*]] = omp.map.info var_ptr(%[[S_DECL]]#1
39+
40+
// CHECK: omp.target host_eval({{.*}}) map_entries({{.*}}, %[[S_MAP]] -> %[[S_TARGET_ARG:.*]] : {{.*}}) {
41+
// CHECK: %[[S_DEV_DECL:.*]]:2 = hlfir.declare %[[S_TARGET_ARG]]
42+
// CHECK: omp.teams reduction(@[[OMP_RED]] %[[S_DEV_DECL]]#0 -> %[[RED_TEAMS_ARG:.*]] : !fir.ref<f32>) {
43+
// CHECK: omp.parallel {
44+
// CHECK: omp.distribute {
45+
// CHECK: omp.wsloop reduction(@[[OMP_RED]] %[[RED_TEAMS_ARG]] -> %[[RED_WS_ARG:.*]] : {{.*}}) {
46+
// CHECK: %[[S_WS_DECL:.*]]:2 = hlfir.declare %[[RED_WS_ARG]] {uniq_name = "_QFfooEs"}
47+
// CHECK: %[[S_VAL:.*]] = fir.load %[[S_WS_DECL]]#0
48+
// CHECK: %[[RED_RES:.*]] = arith.addf %[[S_VAL]], %{{.*}} fastmath<contract> : f32
49+
// CHECK: hlfir.assign %[[RED_RES]] to %[[S_WS_DECL]]#0
50+
// CHECK: }
51+
// CHECK: }
52+
// CHECK: }
53+
// CHECK: }

0 commit comments

Comments
 (0)