Skip to content

Commit e681a9f

Browse files
committed
[flang][OpenMP] do concurrent: support local on device
Extends support for mapping `do concurrent` to the device by adding support for `local` specifiers. The changes in this PR map the local variable to the `omp.target` op and use the mapped value as the `private` clause operand in the nested `omp.parallel` op.
1 parent 6d564c6 commit e681a9f

File tree

3 files changed

+175
-78
lines changed

3 files changed

+175
-78
lines changed

flang/include/flang/Optimizer/Dialect/FIROps.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3894,6 +3894,18 @@ def fir_DoConcurrentLoopOp : fir_Op<"do_concurrent.loop",
38943894
return getReduceVars().size();
38953895
}
38963896

3897+
/// Returns the offset of the first induction variable in the body's block
/// argument list. Induction variables come first, so this is always 0.
unsigned getInductionVarsStart() {
3898+
return 0;
3899+
}
3900+
3901+
/// Returns the offset of the first `local` entry, which equals the number of
/// induction variables.
unsigned getLocalOperandsStart() {
3902+
return getNumInductionVars();
3903+
}
3904+
3905+
/// Returns the offset of the first `reduce` entry: the `local` start offset
/// plus the number of `local` operands.
unsigned getReduceOperandsStart() {
3906+
return getLocalOperandsStart() + getNumLocalOperands();
3907+
}
3908+
38973909
/// Returns the leading `getNumInductionVars()` block arguments of the loop
/// body, i.e. the induction variables.
mlir::Block::BlockArgListType getInductionVars() {
38983910
return getBody()->getArguments().slice(0, getNumInductionVars());
38993911
}

flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp

Lines changed: 114 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,9 @@ void collectLoopLiveIns(fir::DoConcurrentLoopOp loop,
138138

139139
liveIns.push_back(operand->get());
140140
});
141+
142+
for (mlir::Value local : loop.getLocalVars())
143+
liveIns.push_back(local);
141144
}
142145

143146
/// Collects values that are local to a loop: "loop-local values". A loop-local
@@ -252,8 +255,7 @@ class DoConcurrentConversion
252255
.getIsTargetDevice();
253256

254257
mlir::omp::TargetOperands targetClauseOps;
255-
genLoopNestClauseOps(doLoop.getLoc(), rewriter, loop, mapper,
256-
loopNestClauseOps,
258+
genLoopNestClauseOps(doLoop.getLoc(), rewriter, loop, loopNestClauseOps,
257259
isTargetDevice ? nullptr : &targetClauseOps);
258260

259261
LiveInShapeInfoMap liveInShapeInfoMap;
@@ -275,14 +277,13 @@ class DoConcurrentConversion
275277
}
276278

277279
mlir::omp::ParallelOp parallelOp =
278-
genParallelOp(doLoop.getLoc(), rewriter, ivInfos, mapper);
280+
genParallelOp(rewriter, loop, ivInfos, mapper);
279281

280282
// Only set as composite when part of `distribute parallel do`.
281283
parallelOp.setComposite(mapToDevice);
282284

283285
if (!mapToDevice)
284-
genLoopNestClauseOps(doLoop.getLoc(), rewriter, loop, mapper,
285-
loopNestClauseOps);
286+
genLoopNestClauseOps(doLoop.getLoc(), rewriter, loop, loopNestClauseOps);
286287

287288
for (mlir::Value local : locals)
288289
looputils::localizeLoopLocalValue(local, parallelOp.getRegion(),
@@ -291,10 +292,38 @@ class DoConcurrentConversion
291292
if (mapToDevice)
292293
genDistributeOp(doLoop.getLoc(), rewriter).setComposite(/*val=*/true);
293294

294-
mlir::omp::LoopNestOp ompLoopNest =
295+
auto [loopNestOp, wsLoopOp] =
295296
genWsLoopOp(rewriter, loop, mapper, loopNestClauseOps,
296297
/*isComposite=*/mapToDevice);
297298

299+
// `local`/`reduce` region arguments are cloned from the `do concurrent` loop
300+
// to the loopnest op when the region is cloned above. Instead, they belong on
301+
// the parallel op (`local`, device only) or the workshare loop (all others).
302+
if (mapToDevice) {
303+
for (auto [parallelArg, loopNestArg] : llvm::zip_equal(
304+
parallelOp.getRegion().getArguments(),
305+
loopNestOp.getRegion().getArguments().slice(
306+
loop.getLocalOperandsStart(), loop.getNumLocalOperands())))
307+
rewriter.replaceAllUsesWith(loopNestArg, parallelArg);
308+
309+
for (auto [wsloopArg, loopNestArg] : llvm::zip_equal(
310+
wsLoopOp.getRegion().getArguments(),
311+
loopNestOp.getRegion().getArguments().slice(
312+
loop.getReduceOperandsStart(), loop.getNumReduceOperands())))
313+
rewriter.replaceAllUsesWith(loopNestArg, wsloopArg);
314+
} else {
315+
for (auto [wsloopArg, loopNestArg] :
316+
llvm::zip_equal(wsLoopOp.getRegion().getArguments(),
317+
loopNestOp.getRegion().getArguments().drop_front(
318+
loopNestClauseOps.loopLowerBounds.size())))
319+
rewriter.replaceAllUsesWith(loopNestArg, wsloopArg);
320+
}
321+
322+
for (unsigned i = 0;
323+
i < loop.getLocalVars().size() + loop.getReduceVars().size(); ++i)
324+
loopNestOp.getRegion().eraseArgument(
325+
loopNestClauseOps.loopLowerBounds.size());
326+
298327
rewriter.setInsertionPoint(doLoop);
299328
fir::FirOpBuilder builder(
300329
rewriter,
@@ -315,7 +344,7 @@ class DoConcurrentConversion
315344
// Mark `unordered` loops that are not perfectly nested to be skipped from
316345
// the legality check of the `ConversionTarget` since we are not interested
317346
// in mapping them to OpenMP.
318-
ompLoopNest->walk([&](fir::DoConcurrentOp doLoop) {
347+
loopNestOp->walk([&](fir::DoConcurrentOp doLoop) {
319348
concurrentLoopsToSkip.insert(doLoop);
320349
});
321350

@@ -371,11 +400,21 @@ class DoConcurrentConversion
371400
llvm::DenseMap<mlir::Value, TargetDeclareShapeCreationInfo>;
372401

373402
mlir::omp::ParallelOp
374-
genParallelOp(mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
403+
genParallelOp(mlir::ConversionPatternRewriter &rewriter,
404+
fir::DoConcurrentLoopOp loop,
375405
looputils::InductionVariableInfos &ivInfos,
376406
mlir::IRMapping &mapper) const {
377-
auto parallelOp = mlir::omp::ParallelOp::create(rewriter, loc);
378-
rewriter.createBlock(&parallelOp.getRegion());
407+
mlir::omp::ParallelOperands parallelOps;
408+
409+
if (mapToDevice)
410+
genPrivatizers(rewriter, mapper, loop, parallelOps);
411+
412+
mlir::Location loc = loop.getLoc();
413+
auto parallelOp = mlir::omp::ParallelOp::create(rewriter, loc, parallelOps);
414+
Fortran::common::openmp::EntryBlockArgs parallelArgs;
415+
parallelArgs.priv.vars = parallelOps.privateVars;
416+
Fortran::common::openmp::genEntryBlock(rewriter, parallelArgs,
417+
parallelOp.getRegion());
379418
rewriter.setInsertionPoint(mlir::omp::TerminatorOp::create(rewriter, loc));
380419

381420
genLoopNestIndVarAllocs(rewriter, ivInfos, mapper);
@@ -412,7 +451,7 @@ class DoConcurrentConversion
412451

413452
void genLoopNestClauseOps(
414453
mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
415-
fir::DoConcurrentLoopOp loop, mlir::IRMapping &mapper,
454+
fir::DoConcurrentLoopOp loop,
416455
mlir::omp::LoopNestOperands &loopNestClauseOps,
417456
mlir::omp::TargetOperands *targetClauseOps = nullptr) const {
418457
assert(loopNestClauseOps.loopLowerBounds.empty() &&
@@ -443,59 +482,14 @@ class DoConcurrentConversion
443482
loopNestClauseOps.loopInclusive = rewriter.getUnitAttr();
444483
}
445484

446-
mlir::omp::LoopNestOp
485+
std::pair<mlir::omp::LoopNestOp, mlir::omp::WsloopOp>
447486
genWsLoopOp(mlir::ConversionPatternRewriter &rewriter,
448487
fir::DoConcurrentLoopOp loop, mlir::IRMapping &mapper,
449488
const mlir::omp::LoopNestOperands &clauseOps,
450489
bool isComposite) const {
451490
mlir::omp::WsloopOperands wsloopClauseOps;
452-
453-
auto cloneFIRRegionToOMP = [&rewriter](mlir::Region &firRegion,
454-
mlir::Region &ompRegion) {
455-
if (!firRegion.empty()) {
456-
rewriter.cloneRegionBefore(firRegion, ompRegion, ompRegion.begin());
457-
auto firYield =
458-
mlir::cast<fir::YieldOp>(ompRegion.back().getTerminator());
459-
rewriter.setInsertionPoint(firYield);
460-
mlir::omp::YieldOp::create(rewriter, firYield.getLoc(),
461-
firYield.getOperands());
462-
rewriter.eraseOp(firYield);
463-
}
464-
};
465-
466-
// For `local` (and `local_init`) opernads, emit corresponding `private`
467-
// clauses and attach these clauses to the workshare loop.
468-
if (!loop.getLocalVars().empty())
469-
for (auto [op, sym, arg] : llvm::zip_equal(
470-
loop.getLocalVars(),
471-
loop.getLocalSymsAttr().getAsRange<mlir::SymbolRefAttr>(),
472-
loop.getRegionLocalArgs())) {
473-
auto localizer = moduleSymbolTable.lookup<fir::LocalitySpecifierOp>(
474-
sym.getLeafReference());
475-
if (localizer.getLocalitySpecifierType() ==
476-
fir::LocalitySpecifierType::LocalInit)
477-
TODO(localizer.getLoc(),
478-
"local_init conversion is not supported yet");
479-
480-
mlir::OpBuilder::InsertionGuard guard(rewriter);
481-
rewriter.setInsertionPointAfter(localizer);
482-
483-
auto privatizer = mlir::omp::PrivateClauseOp::create(
484-
rewriter, localizer.getLoc(), sym.getLeafReference().str() + ".omp",
485-
localizer.getTypeAttr().getValue(),
486-
mlir::omp::DataSharingClauseType::Private);
487-
488-
cloneFIRRegionToOMP(localizer.getInitRegion(),
489-
privatizer.getInitRegion());
490-
cloneFIRRegionToOMP(localizer.getDeallocRegion(),
491-
privatizer.getDeallocRegion());
492-
493-
moduleSymbolTable.insert(privatizer);
494-
495-
wsloopClauseOps.privateVars.push_back(op);
496-
wsloopClauseOps.privateSyms.push_back(
497-
mlir::SymbolRefAttr::get(privatizer));
498-
}
491+
if (!mapToDevice)
492+
genPrivatizers(rewriter, mapper, loop, wsloopClauseOps);
499493

500494
if (!loop.getReduceVars().empty()) {
501495
for (auto [op, byRef, sym, arg] : llvm::zip_equal(
@@ -518,15 +512,15 @@ class DoConcurrentConversion
518512
rewriter, firReducer.getLoc(), ompReducerName,
519513
firReducer.getTypeAttr().getValue());
520514

521-
cloneFIRRegionToOMP(firReducer.getAllocRegion(),
515+
cloneFIRRegionToOMP(rewriter, firReducer.getAllocRegion(),
522516
ompReducer.getAllocRegion());
523-
cloneFIRRegionToOMP(firReducer.getInitializerRegion(),
517+
cloneFIRRegionToOMP(rewriter, firReducer.getInitializerRegion(),
524518
ompReducer.getInitializerRegion());
525-
cloneFIRRegionToOMP(firReducer.getReductionRegion(),
519+
cloneFIRRegionToOMP(rewriter, firReducer.getReductionRegion(),
526520
ompReducer.getReductionRegion());
527-
cloneFIRRegionToOMP(firReducer.getAtomicReductionRegion(),
521+
cloneFIRRegionToOMP(rewriter, firReducer.getAtomicReductionRegion(),
528522
ompReducer.getAtomicReductionRegion());
529-
cloneFIRRegionToOMP(firReducer.getCleanupRegion(),
523+
cloneFIRRegionToOMP(rewriter, firReducer.getCleanupRegion(),
530524
ompReducer.getCleanupRegion());
531525
moduleSymbolTable.insert(ompReducer);
532526
}
@@ -558,21 +552,10 @@ class DoConcurrentConversion
558552

559553
rewriter.setInsertionPointToEnd(&loopNestOp.getRegion().back());
560554
mlir::omp::YieldOp::create(rewriter, loop->getLoc());
555+
loop->getParentOfType<mlir::ModuleOp>().print(
556+
llvm::errs(), mlir::OpPrintingFlags().assumeVerified());
561557

562-
// `local` region arguments are transferred/cloned from the `do concurrent`
563-
// loop to the loopnest op when the region is cloned above. Instead, these
564-
// region arguments should be on the workshare loop's region.
565-
for (auto [wsloopArg, loopNestArg] :
566-
llvm::zip_equal(wsloopOp.getRegion().getArguments(),
567-
loopNestOp.getRegion().getArguments().drop_front(
568-
clauseOps.loopLowerBounds.size())))
569-
rewriter.replaceAllUsesWith(loopNestArg, wsloopArg);
570-
571-
for (unsigned i = 0;
572-
i < loop.getLocalVars().size() + loop.getReduceVars().size(); ++i)
573-
loopNestOp.getRegion().eraseArgument(clauseOps.loopLowerBounds.size());
574-
575-
return loopNestOp;
558+
return {loopNestOp, wsloopOp};
576559
}
577560

578561
void genBoundsOps(fir::FirOpBuilder &builder, mlir::Value liveIn,
@@ -813,6 +796,59 @@ class DoConcurrentConversion
813796
return distOp;
814797
}
815798

799+
/// Clones the blocks of `firRegion` to the front of `ompRegion`, then swaps
/// the `fir.yield` that terminates the last block of `ompRegion` for an
/// `omp.yield` carrying the same operands. Does nothing when `firRegion` is
/// empty.
///
/// Note: this moves `rewriter`'s insertion point and does not restore it;
/// callers that care must guard it themselves.
void cloneFIRRegionToOMP(mlir::ConversionPatternRewriter &rewriter,
                         mlir::Region &firRegion,
                         mlir::Region &ompRegion) const {
  if (firRegion.empty())
    return;

  rewriter.cloneRegionBefore(firRegion, ompRegion, ompRegion.begin());

  // The cloned terminator is still a FIR op; re-create it in the OpenMP
  // dialect and drop the original.
  mlir::Operation *terminator = ompRegion.back().getTerminator();
  auto oldYield = mlir::cast<fir::YieldOp>(terminator);
  rewriter.setInsertionPoint(oldYield);
  mlir::omp::YieldOp::create(rewriter, oldYield.getLoc(),
                             oldYield.getOperands());
  rewriter.eraseOp(oldYield);
}
812+
813+
/// Creates an `omp.private` privatizer for every `local` specifier of `loop`
/// and appends the matching variable/symbol pair to `privateClauseOps`.
///
/// The op that ends up carrying the `private` clauses is chosen by the caller
/// through the operand structure it passes in; this function only fills it.
///
/// \param rewriter         used to create the privatizer ops. An insertion
///                         guard is taken per specifier, so the caller's
///                         insertion point is preserved.
/// \param mapper           consulted only when `mapToDevice` is set, to swap
///                         each local variable for its mapped value; expected
///                         to already hold a mapping for every local variable.
/// \param loop             the `do concurrent` loop whose `local` specifiers
///                         are converted.
/// \param privateClauseOps receives one `privateVars`/`privateSyms` entry per
///                         `local` specifier, in specifier order.
///
/// `local_init` specifiers are not supported yet and hit a `TODO`.
void genPrivatizers(mlir::ConversionPatternRewriter &rewriter,
                    mlir::IRMapping &mapper, fir::DoConcurrentLoopOp loop,
                    mlir::omp::PrivateClauseOps &privateClauseOps) const {
  // Nothing to privatize. Bail out before touching the symbols attribute.
  // NOTE(review): presumably the attribute is absent when the loop has no
  // `local` specifiers - confirm.
  if (loop.getLocalVars().empty())
    return;

  for (auto [var, sym, arg] : llvm::zip_equal(
           loop.getLocalVars(),
           loop.getLocalSymsAttr().getAsRange<mlir::SymbolRefAttr>(),
           loop.getRegionLocalArgs())) {
    auto localizer = moduleSymbolTable.lookup<fir::LocalitySpecifierOp>(
        sym.getLeafReference());
    // The lookup yields a null op when the symbol is missing or of another
    // kind; fail loudly here instead of dereferencing it below.
    assert(localizer &&
           "expected `local` symbol to resolve to a locality specifier op");

    if (localizer.getLocalitySpecifierType() ==
        fir::LocalitySpecifierType::LocalInit)
      TODO(localizer.getLoc(), "local_init conversion is not supported yet");

    // Emit the OpenMP privatizer right after its FIR counterpart; the guard
    // puts the insertion point back at the end of this iteration.
    mlir::OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointAfter(localizer);

    auto privatizer = mlir::omp::PrivateClauseOp::create(
        rewriter, localizer.getLoc(), sym.getLeafReference().str() + ".omp",
        localizer.getTypeAttr().getValue(),
        mlir::omp::DataSharingClauseType::Private);

    cloneFIRRegionToOMP(rewriter, localizer.getInitRegion(),
                        privatizer.getInitRegion());
    cloneFIRRegionToOMP(rewriter, localizer.getDeallocRegion(),
                        privatizer.getDeallocRegion());

    moduleSymbolTable.insert(privatizer);

    // On the device, the original variable is replaced by the value recorded
    // for it in `mapper`.
    privateClauseOps.privateVars.push_back(mapToDevice ? mapper.lookup(var)
                                                       : var);
    privateClauseOps.privateSyms.push_back(
        mlir::SymbolRefAttr::get(privatizer));
  }
}
851+
816852
bool mapToDevice;
817853
llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip;
818854
mlir::SymbolTable &moduleSymbolTable;
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// RUN: fir-opt --omp-do-concurrent-conversion="map-to=device" %s -o - | FileCheck %s
2+
3+
fir.local {type = local} @_QFfooEmy_local_private_f32 : f32
4+
5+
// Test input: a single-level `do concurrent` loop with one `local` specifier
// (`my_local`) that is assigned inside the loop body.
func.func @_QPfoo() {
6+
%0 = fir.dummy_scope : !fir.dscope
7+
%3 = fir.alloca f32 {bindc_name = "my_local", uniq_name = "_QFfooEmy_local"}
8+
%4:2 = hlfir.declare %3 {uniq_name = "_QFfooEmy_local"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
9+
10+
%c1 = arith.constant 1 : index
11+
%c10 = arith.constant 10 : index
12+
13+
fir.do_concurrent {
14+
%7 = fir.alloca i32 {bindc_name = "i"}
15+
%8:2 = hlfir.declare %7 {uniq_name = "_QFfooEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
16+
17+
// `%4#0` is the host variable; `%arg1` is its loop-local counterpart.
fir.do_concurrent.loop (%arg0) = (%c1) to (%c10) step (%c1) local(@_QFfooEmy_local_private_f32 %4#0 -> %arg1 : !fir.ref<f32>) {
18+
%9 = fir.convert %arg0 : (index) -> i32
19+
fir.store %9 to %8#0 : !fir.ref<i32>
20+
%10:2 = hlfir.declare %arg1 {uniq_name = "_QFfooEmy_local"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
21+
%cst = arith.constant 4.200000e+01 : f32
22+
hlfir.assign %cst to %10#0 : f32, !fir.ref<f32>
23+
}
24+
}
25+
return
26+
}
27+
28+
// CHECK: omp.private {type = private} @[[OMP_PRIVATIZER:.*.omp]] : f32
29+
30+
// CHECK: %[[LOCAL_DECL:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "{{.*}}my_local"}
31+
// CHECK: %[[LOCAL_MAP:.*]] = omp.map.info var_ptr(%[[LOCAL_DECL]]#1 : {{.*}})
32+
33+
// CHECK: omp.target host_eval({{.*}}) map_entries({{.*}}, %[[LOCAL_MAP]] -> %[[LOCAL_MAP_ARG:.*]] : {{.*}}) {
34+
// CHECK: %[[LOCAL_DEV_DECL:.*]]:2 = hlfir.declare %[[LOCAL_MAP_ARG]] {uniq_name = "_QFfooEmy_local"}
35+
36+
// CHECK: omp.teams {
37+
// CHECK: omp.parallel private(@[[OMP_PRIVATIZER]] %[[LOCAL_DEV_DECL]]#0 -> %[[LOCAL_PRIV_ARG:.*]] : {{.*}}) {
38+
// CHECK: omp.distribute {
39+
// CHECK: omp.wsloop {
40+
// CHECK: omp.loop_nest {{.*}} {
41+
// CHECK: %[[LOCAL_LOOP_DECL:.*]]:2 = hlfir.declare %[[LOCAL_PRIV_ARG]] {uniq_name = "_QFfooEmy_local"}
42+
// CHECK: hlfir.assign %{{.*}} to %[[LOCAL_LOOP_DECL]]#0
43+
// CHECK: omp.yield
44+
// CHECK: }
45+
// CHECK: }
46+
// CHECK: }
47+
// CHECK: }
48+
// CHECK: }
49+
// CHECK: }

0 commit comments

Comments
 (0)