Skip to content

Commit 9008c44

Browse files
authored
[flang][OpenMP] do concurrent: support local on device (#157638)
Extends support for mapping `do concurrent` on the device by adding support for `local` specifiers. The changes in this PR map the local variable to the `omp.target` op and uses the mapped value as the `private` clause operand in the nested `omp.parallel` op. - #155754 - #155987 - #155992 - #155993 - #157638 ◀️ - #156610 - #156837
1 parent c69a70b commit 9008c44

File tree

3 files changed

+183
-78
lines changed

3 files changed

+183
-78
lines changed

flang/include/flang/Optimizer/Dialect/FIROps.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3894,6 +3894,18 @@ def fir_DoConcurrentLoopOp : fir_Op<"do_concurrent.loop",
38943894
return getReduceVars().size();
38953895
}
38963896

3897+
unsigned getInductionVarsStart() {
3898+
return 0;
3899+
}
3900+
3901+
unsigned getLocalOperandsStart() {
3902+
return getNumInductionVars();
3903+
}
3904+
3905+
unsigned getReduceOperandsStart() {
3906+
return getLocalOperandsStart() + getNumLocalOperands();
3907+
}
3908+
38973909
mlir::Block::BlockArgListType getInductionVars() {
38983910
return getBody()->getArguments().slice(0, getNumInductionVars());
38993911
}

flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp

Lines changed: 122 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,9 @@ void collectLoopLiveIns(fir::DoConcurrentLoopOp loop,
138138

139139
liveIns.push_back(operand->get());
140140
});
141+
142+
for (mlir::Value local : loop.getLocalVars())
143+
liveIns.push_back(local);
141144
}
142145

143146
/// Collects values that are local to a loop: "loop-local values". A loop-local
@@ -298,8 +301,7 @@ class DoConcurrentConversion
298301
.getIsTargetDevice();
299302

300303
mlir::omp::TargetOperands targetClauseOps;
301-
genLoopNestClauseOps(doLoop.getLoc(), rewriter, loop, mapper,
302-
loopNestClauseOps,
304+
genLoopNestClauseOps(doLoop.getLoc(), rewriter, loop, loopNestClauseOps,
303305
isTargetDevice ? nullptr : &targetClauseOps);
304306

305307
LiveInShapeInfoMap liveInShapeInfoMap;
@@ -321,14 +323,13 @@ class DoConcurrentConversion
321323
}
322324

323325
mlir::omp::ParallelOp parallelOp =
324-
genParallelOp(doLoop.getLoc(), rewriter, ivInfos, mapper);
326+
genParallelOp(rewriter, loop, ivInfos, mapper);
325327

326328
// Only set as composite when part of `distribute parallel do`.
327329
parallelOp.setComposite(mapToDevice);
328330

329331
if (!mapToDevice)
330-
genLoopNestClauseOps(doLoop.getLoc(), rewriter, loop, mapper,
331-
loopNestClauseOps);
332+
genLoopNestClauseOps(doLoop.getLoc(), rewriter, loop, loopNestClauseOps);
332333

333334
for (mlir::Value local : locals)
334335
looputils::localizeLoopLocalValue(local, parallelOp.getRegion(),
@@ -337,10 +338,38 @@ class DoConcurrentConversion
337338
if (mapToDevice)
338339
genDistributeOp(doLoop.getLoc(), rewriter).setComposite(/*val=*/true);
339340

340-
mlir::omp::LoopNestOp ompLoopNest =
341+
auto [loopNestOp, wsLoopOp] =
341342
genWsLoopOp(rewriter, loop, mapper, loopNestClauseOps,
342343
/*isComposite=*/mapToDevice);
343344

345+
// `local` region arguments are transferred/cloned from the `do concurrent`
346+
// loop to the loopnest op when the region is cloned above. Instead, these
347+
// region arguments should be on the workshare loop's region.
348+
if (mapToDevice) {
349+
for (auto [parallelArg, loopNestArg] : llvm::zip_equal(
350+
parallelOp.getRegion().getArguments(),
351+
loopNestOp.getRegion().getArguments().slice(
352+
loop.getLocalOperandsStart(), loop.getNumLocalOperands())))
353+
rewriter.replaceAllUsesWith(loopNestArg, parallelArg);
354+
355+
for (auto [wsloopArg, loopNestArg] : llvm::zip_equal(
356+
wsLoopOp.getRegion().getArguments(),
357+
loopNestOp.getRegion().getArguments().slice(
358+
loop.getReduceOperandsStart(), loop.getNumReduceOperands())))
359+
rewriter.replaceAllUsesWith(loopNestArg, wsloopArg);
360+
} else {
361+
for (auto [wsloopArg, loopNestArg] :
362+
llvm::zip_equal(wsLoopOp.getRegion().getArguments(),
363+
loopNestOp.getRegion().getArguments().drop_front(
364+
loopNestClauseOps.loopLowerBounds.size())))
365+
rewriter.replaceAllUsesWith(loopNestArg, wsloopArg);
366+
}
367+
368+
for (unsigned i = 0;
369+
i < loop.getLocalVars().size() + loop.getReduceVars().size(); ++i)
370+
loopNestOp.getRegion().eraseArgument(
371+
loopNestClauseOps.loopLowerBounds.size());
372+
344373
rewriter.setInsertionPoint(doLoop);
345374
fir::FirOpBuilder builder(
346375
rewriter,
@@ -361,7 +390,7 @@ class DoConcurrentConversion
361390
// Mark `unordered` loops that are not perfectly nested to be skipped from
362391
// the legality check of the `ConversionTarget` since we are not interested
363392
// in mapping them to OpenMP.
364-
ompLoopNest->walk([&](fir::DoConcurrentOp doLoop) {
393+
loopNestOp->walk([&](fir::DoConcurrentOp doLoop) {
365394
concurrentLoopsToSkip.insert(doLoop);
366395
});
367396

@@ -372,11 +401,21 @@ class DoConcurrentConversion
372401

373402
private:
374403
mlir::omp::ParallelOp
375-
genParallelOp(mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
404+
genParallelOp(mlir::ConversionPatternRewriter &rewriter,
405+
fir::DoConcurrentLoopOp loop,
376406
looputils::InductionVariableInfos &ivInfos,
377407
mlir::IRMapping &mapper) const {
378-
auto parallelOp = mlir::omp::ParallelOp::create(rewriter, loc);
379-
rewriter.createBlock(&parallelOp.getRegion());
408+
mlir::omp::ParallelOperands parallelOps;
409+
410+
if (mapToDevice)
411+
genPrivatizers(rewriter, mapper, loop, parallelOps);
412+
413+
mlir::Location loc = loop.getLoc();
414+
auto parallelOp = mlir::omp::ParallelOp::create(rewriter, loc, parallelOps);
415+
Fortran::common::openmp::EntryBlockArgs parallelArgs;
416+
parallelArgs.priv.vars = parallelOps.privateVars;
417+
Fortran::common::openmp::genEntryBlock(rewriter, parallelArgs,
418+
parallelOp.getRegion());
380419
rewriter.setInsertionPoint(mlir::omp::TerminatorOp::create(rewriter, loc));
381420

382421
genLoopNestIndVarAllocs(rewriter, ivInfos, mapper);
@@ -413,7 +452,7 @@ class DoConcurrentConversion
413452

414453
void genLoopNestClauseOps(
415454
mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
416-
fir::DoConcurrentLoopOp loop, mlir::IRMapping &mapper,
455+
fir::DoConcurrentLoopOp loop,
417456
mlir::omp::LoopNestOperands &loopNestClauseOps,
418457
mlir::omp::TargetOperands *targetClauseOps = nullptr) const {
419458
assert(loopNestClauseOps.loopLowerBounds.empty() &&
@@ -444,59 +483,14 @@ class DoConcurrentConversion
444483
loopNestClauseOps.loopInclusive = rewriter.getUnitAttr();
445484
}
446485

447-
mlir::omp::LoopNestOp
486+
std::pair<mlir::omp::LoopNestOp, mlir::omp::WsloopOp>
448487
genWsLoopOp(mlir::ConversionPatternRewriter &rewriter,
449488
fir::DoConcurrentLoopOp loop, mlir::IRMapping &mapper,
450489
const mlir::omp::LoopNestOperands &clauseOps,
451490
bool isComposite) const {
452491
mlir::omp::WsloopOperands wsloopClauseOps;
453-
454-
auto cloneFIRRegionToOMP = [&rewriter](mlir::Region &firRegion,
455-
mlir::Region &ompRegion) {
456-
if (!firRegion.empty()) {
457-
rewriter.cloneRegionBefore(firRegion, ompRegion, ompRegion.begin());
458-
auto firYield =
459-
mlir::cast<fir::YieldOp>(ompRegion.back().getTerminator());
460-
rewriter.setInsertionPoint(firYield);
461-
mlir::omp::YieldOp::create(rewriter, firYield.getLoc(),
462-
firYield.getOperands());
463-
rewriter.eraseOp(firYield);
464-
}
465-
};
466-
467-
// For `local` (and `local_init`) opernads, emit corresponding `private`
468-
// clauses and attach these clauses to the workshare loop.
469-
if (!loop.getLocalVars().empty())
470-
for (auto [op, sym, arg] : llvm::zip_equal(
471-
loop.getLocalVars(),
472-
loop.getLocalSymsAttr().getAsRange<mlir::SymbolRefAttr>(),
473-
loop.getRegionLocalArgs())) {
474-
auto localizer = moduleSymbolTable.lookup<fir::LocalitySpecifierOp>(
475-
sym.getLeafReference());
476-
if (localizer.getLocalitySpecifierType() ==
477-
fir::LocalitySpecifierType::LocalInit)
478-
TODO(localizer.getLoc(),
479-
"local_init conversion is not supported yet");
480-
481-
mlir::OpBuilder::InsertionGuard guard(rewriter);
482-
rewriter.setInsertionPointAfter(localizer);
483-
484-
auto privatizer = mlir::omp::PrivateClauseOp::create(
485-
rewriter, localizer.getLoc(), sym.getLeafReference().str() + ".omp",
486-
localizer.getTypeAttr().getValue(),
487-
mlir::omp::DataSharingClauseType::Private);
488-
489-
cloneFIRRegionToOMP(localizer.getInitRegion(),
490-
privatizer.getInitRegion());
491-
cloneFIRRegionToOMP(localizer.getDeallocRegion(),
492-
privatizer.getDeallocRegion());
493-
494-
moduleSymbolTable.insert(privatizer);
495-
496-
wsloopClauseOps.privateVars.push_back(op);
497-
wsloopClauseOps.privateSyms.push_back(
498-
mlir::SymbolRefAttr::get(privatizer));
499-
}
492+
if (!mapToDevice)
493+
genPrivatizers(rewriter, mapper, loop, wsloopClauseOps);
500494

501495
if (!loop.getReduceVars().empty()) {
502496
for (auto [op, byRef, sym, arg] : llvm::zip_equal(
@@ -519,15 +513,15 @@ class DoConcurrentConversion
519513
rewriter, firReducer.getLoc(), ompReducerName,
520514
firReducer.getTypeAttr().getValue());
521515

522-
cloneFIRRegionToOMP(firReducer.getAllocRegion(),
516+
cloneFIRRegionToOMP(rewriter, firReducer.getAllocRegion(),
523517
ompReducer.getAllocRegion());
524-
cloneFIRRegionToOMP(firReducer.getInitializerRegion(),
518+
cloneFIRRegionToOMP(rewriter, firReducer.getInitializerRegion(),
525519
ompReducer.getInitializerRegion());
526-
cloneFIRRegionToOMP(firReducer.getReductionRegion(),
520+
cloneFIRRegionToOMP(rewriter, firReducer.getReductionRegion(),
527521
ompReducer.getReductionRegion());
528-
cloneFIRRegionToOMP(firReducer.getAtomicReductionRegion(),
522+
cloneFIRRegionToOMP(rewriter, firReducer.getAtomicReductionRegion(),
529523
ompReducer.getAtomicReductionRegion());
530-
cloneFIRRegionToOMP(firReducer.getCleanupRegion(),
524+
cloneFIRRegionToOMP(rewriter, firReducer.getCleanupRegion(),
531525
ompReducer.getCleanupRegion());
532526
moduleSymbolTable.insert(ompReducer);
533527
}
@@ -559,21 +553,10 @@ class DoConcurrentConversion
559553

560554
rewriter.setInsertionPointToEnd(&loopNestOp.getRegion().back());
561555
mlir::omp::YieldOp::create(rewriter, loop->getLoc());
556+
loop->getParentOfType<mlir::ModuleOp>().print(
557+
llvm::errs(), mlir::OpPrintingFlags().assumeVerified());
562558

563-
// `local` region arguments are transferred/cloned from the `do concurrent`
564-
// loop to the loopnest op when the region is cloned above. Instead, these
565-
// region arguments should be on the workshare loop's region.
566-
for (auto [wsloopArg, loopNestArg] :
567-
llvm::zip_equal(wsloopOp.getRegion().getArguments(),
568-
loopNestOp.getRegion().getArguments().drop_front(
569-
clauseOps.loopLowerBounds.size())))
570-
rewriter.replaceAllUsesWith(loopNestArg, wsloopArg);
571-
572-
for (unsigned i = 0;
573-
i < loop.getLocalVars().size() + loop.getReduceVars().size(); ++i)
574-
loopNestOp.getRegion().eraseArgument(clauseOps.loopLowerBounds.size());
575-
576-
return loopNestOp;
559+
return {loopNestOp, wsloopOp};
577560
}
578561

579562
void genBoundsOps(fir::FirOpBuilder &builder, mlir::Value liveIn,
@@ -817,6 +800,67 @@ class DoConcurrentConversion
817800
return distOp;
818801
}
819802

803+
void cloneFIRRegionToOMP(mlir::ConversionPatternRewriter &rewriter,
804+
mlir::Region &firRegion,
805+
mlir::Region &ompRegion) const {
806+
if (!firRegion.empty()) {
807+
rewriter.cloneRegionBefore(firRegion, ompRegion, ompRegion.begin());
808+
auto firYield =
809+
mlir::cast<fir::YieldOp>(ompRegion.back().getTerminator());
810+
rewriter.setInsertionPoint(firYield);
811+
mlir::omp::YieldOp::create(rewriter, firYield.getLoc(),
812+
firYield.getOperands());
813+
rewriter.eraseOp(firYield);
814+
}
815+
}
816+
817+
/// Generate bodies of OpenMP privatizers by cloning the bodies of FIR
818+
/// privatizers.
819+
///
820+
/// \param [in] rewriter - used to driver IR generation for privatizers.
821+
/// \param [in] mapper - value mapping from FIR to OpenMP constructs.
822+
/// \param [in] loop - FIR loop to convert its localizers.
823+
///
824+
/// \param [out] privateClauseOps - OpenMP privatizers to gen their bodies.
825+
void genPrivatizers(mlir::ConversionPatternRewriter &rewriter,
826+
mlir::IRMapping &mapper, fir::DoConcurrentLoopOp loop,
827+
mlir::omp::PrivateClauseOps &privateClauseOps) const {
828+
// For `local` (and `local_init`) operands, emit corresponding `private`
829+
// clauses and attach these clauses to the workshare loop.
830+
if (!loop.getLocalVars().empty())
831+
for (auto [var, sym, arg] : llvm::zip_equal(
832+
loop.getLocalVars(),
833+
loop.getLocalSymsAttr().getAsRange<mlir::SymbolRefAttr>(),
834+
loop.getRegionLocalArgs())) {
835+
auto localizer = moduleSymbolTable.lookup<fir::LocalitySpecifierOp>(
836+
sym.getLeafReference());
837+
if (localizer.getLocalitySpecifierType() ==
838+
fir::LocalitySpecifierType::LocalInit)
839+
TODO(localizer.getLoc(),
840+
"local_init conversion is not supported yet");
841+
842+
mlir::OpBuilder::InsertionGuard guard(rewriter);
843+
rewriter.setInsertionPointAfter(localizer);
844+
845+
auto privatizer = mlir::omp::PrivateClauseOp::create(
846+
rewriter, localizer.getLoc(), sym.getLeafReference().str() + ".omp",
847+
localizer.getTypeAttr().getValue(),
848+
mlir::omp::DataSharingClauseType::Private);
849+
850+
cloneFIRRegionToOMP(rewriter, localizer.getInitRegion(),
851+
privatizer.getInitRegion());
852+
cloneFIRRegionToOMP(rewriter, localizer.getDeallocRegion(),
853+
privatizer.getDeallocRegion());
854+
855+
moduleSymbolTable.insert(privatizer);
856+
857+
privateClauseOps.privateVars.push_back(mapToDevice ? mapper.lookup(var)
858+
: var);
859+
privateClauseOps.privateSyms.push_back(
860+
mlir::SymbolRefAttr::get(privatizer));
861+
}
862+
}
863+
820864
bool mapToDevice;
821865
llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip;
822866
mlir::SymbolTable &moduleSymbolTable;
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// RUN: fir-opt --omp-do-concurrent-conversion="map-to=device" %s -o - | FileCheck %s
2+
3+
fir.local {type = local} @_QFfooEmy_local_private_f32 : f32
4+
5+
func.func @_QPfoo() {
6+
%0 = fir.dummy_scope : !fir.dscope
7+
%3 = fir.alloca f32 {bindc_name = "my_local", uniq_name = "_QFfooEmy_local"}
8+
%4:2 = hlfir.declare %3 {uniq_name = "_QFfooEmy_local"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
9+
10+
%c1 = arith.constant 1 : index
11+
%c10 = arith.constant 10 : index
12+
13+
fir.do_concurrent {
14+
%7 = fir.alloca i32 {bindc_name = "i"}
15+
%8:2 = hlfir.declare %7 {uniq_name = "_QFfooEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
16+
17+
fir.do_concurrent.loop (%arg0) = (%c1) to (%c10) step (%c1) local(@_QFfooEmy_local_private_f32 %4#0 -> %arg1 : !fir.ref<f32>) {
18+
%9 = fir.convert %arg0 : (index) -> i32
19+
fir.store %9 to %8#0 : !fir.ref<i32>
20+
%10:2 = hlfir.declare %arg1 {uniq_name = "_QFfooEmy_local"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
21+
%cst = arith.constant 4.200000e+01 : f32
22+
hlfir.assign %cst to %10#0 : f32, !fir.ref<f32>
23+
}
24+
}
25+
return
26+
}
27+
28+
// CHECK: omp.private {type = private} @[[OMP_PRIVATIZER:.*.omp]] : f32
29+
30+
// CHECK: %[[LOCAL_DECL:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "{{.*}}my_local"}
31+
// CHECK: %[[LOCAL_MAP:.*]] = omp.map.info var_ptr(%[[LOCAL_DECL]]#1 : {{.*}})
32+
33+
// CHECK: omp.target host_eval({{.*}}) map_entries({{.*}}, %[[LOCAL_MAP]] -> %[[LOCAL_MAP_ARG:.*]] : {{.*}}) {
34+
// CHECK: %[[LOCAL_DEV_DECL:.*]]:2 = hlfir.declare %[[LOCAL_MAP_ARG]] {uniq_name = "_QFfooEmy_local"}
35+
36+
// CHECK: omp.teams {
37+
// CHECK: omp.parallel private(@[[OMP_PRIVATIZER]] %[[LOCAL_DEV_DECL]]#0 -> %[[LOCAL_PRIV_ARG:.*]] : {{.*}}) {
38+
// CHECK: omp.distribute {
39+
// CHECK: omp.wsloop {
40+
// CHECK: omp.loop_nest {{.*}} {
41+
// CHECK: %[[LOCAL_LOOP_DECL:.*]]:2 = hlfir.declare %[[LOCAL_PRIV_ARG]] {uniq_name = "_QFfooEmy_local"}
42+
// CHECK: hlfir.assign %{{.*}} to %[[LOCAL_LOOP_DECL]]#0
43+
// CHECK: omp.yield
44+
// CHECK: }
45+
// CHECK: }
46+
// CHECK: }
47+
// CHECK: }
48+
// CHECK: }
49+
// CHECK: }

0 commit comments

Comments
 (0)