Commit 78fc5ed

[flang][OpenMP] do concurrent: support local on device
Extends support for mapping `do concurrent` to the device by adding support for `local` specifiers. The changes in this PR map the local variable to the `omp.target` op and use the mapped value as the `private` clause operand in the nested `omp.parallel` op.
1 parent 2fd2022 commit 78fc5ed
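For context, the new FIR test added by this commit roughly corresponds to a Fortran loop of the following shape (an illustrative sketch, not part of the commit; the subroutine and variable names mirror the test, which is written directly in FIR):

subroutine foo()
  implicit none
  real    :: my_local
  integer :: i

  ! The `local` locality specifier gives each iteration its own copy of
  ! `my_local`; when mapped to the device, it now becomes a `private`
  ! clause on the nested `omp.parallel`.
  do concurrent (i = 1:10) local(my_local)
    my_local = 42.0
  end do
end subroutine foo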

3 files changed: 175 additions & 78 deletions

flang/include/flang/Optimizer/Dialect/FIROps.td

Lines changed: 12 additions & 0 deletions
@@ -3894,6 +3894,18 @@ def fir_DoConcurrentLoopOp : fir_Op<"do_concurrent.loop",
       return getReduceVars().size();
     }
 
+    unsigned getInductionVarsStart() {
+      return 0;
+    }
+
+    unsigned getLocalOperandsStart() {
+      return getNumInductionVars();
+    }
+
+    unsigned getReduceOperandsStart() {
+      return getLocalOperandsStart() + getNumLocalOperands();
+    }
+
     mlir::Block::BlockArgListType getInductionVars() {
       return getBody()->getArguments().slice(0, getNumInductionVars());
     }
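These offsets describe how the loop op's region block arguments are laid out: induction variables first, then the arguments for `local` operands, then those for `reduce` operands. For illustration (a hypothetical loop, not from the commit): with two induction variables, one `local` operand, and one `reduce` operand, the arguments are ordered [iv0, iv1, local0, reduce0], so `getLocalOperandsStart()` returns 2 and `getReduceOperandsStart()` returns 3. The conversion pass below uses these offsets to slice the matching argument ranges off the generated `omp.loop_nest` region.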

flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp

Lines changed: 114 additions & 78 deletions
@@ -137,6 +137,9 @@ void collectLoopLiveIns(fir::DoConcurrentLoopOp loop,
 
     liveIns.push_back(operand->get());
   });
+
+  for (mlir::Value local : loop.getLocalVars())
+    liveIns.push_back(local);
 }
 
 /// Collects values that are local to a loop: "loop-local values". A loop-local
@@ -251,8 +254,7 @@ class DoConcurrentConversion
             .getIsTargetDevice();
 
     mlir::omp::TargetOperands targetClauseOps;
-    genLoopNestClauseOps(doLoop.getLoc(), rewriter, loop, mapper,
-                         loopNestClauseOps,
+    genLoopNestClauseOps(doLoop.getLoc(), rewriter, loop, loopNestClauseOps,
                          isTargetDevice ? nullptr : &targetClauseOps);
 
     LiveInShapeInfoMap liveInShapeInfoMap;
@@ -274,14 +276,13 @@ class DoConcurrentConversion
     }
 
     mlir::omp::ParallelOp parallelOp =
-        genParallelOp(doLoop.getLoc(), rewriter, ivInfos, mapper);
+        genParallelOp(rewriter, loop, ivInfos, mapper);
 
     // Only set as composite when part of `distribute parallel do`.
     parallelOp.setComposite(mapToDevice);
 
     if (!mapToDevice)
-      genLoopNestClauseOps(doLoop.getLoc(), rewriter, loop, mapper,
-                           loopNestClauseOps);
+      genLoopNestClauseOps(doLoop.getLoc(), rewriter, loop, loopNestClauseOps);
 
     for (mlir::Value local : locals)
       looputils::localizeLoopLocalValue(local, parallelOp.getRegion(),
@@ -290,10 +291,38 @@ class DoConcurrentConversion
     if (mapToDevice)
       genDistributeOp(doLoop.getLoc(), rewriter).setComposite(/*val=*/true);
 
-    mlir::omp::LoopNestOp ompLoopNest =
+    auto [loopNestOp, wsLoopOp] =
         genWsLoopOp(rewriter, loop, mapper, loopNestClauseOps,
                     /*isComposite=*/mapToDevice);
 
+    // `local` region arguments are transferred/cloned from the `do concurrent`
+    // loop to the loopnest op when the region is cloned above. Instead, these
+    // region arguments should be on the workshare loop's region.
+    if (mapToDevice) {
+      for (auto [parallelArg, loopNestArg] : llvm::zip_equal(
+               parallelOp.getRegion().getArguments(),
+               loopNestOp.getRegion().getArguments().slice(
+                   loop.getLocalOperandsStart(), loop.getNumLocalOperands())))
+        rewriter.replaceAllUsesWith(loopNestArg, parallelArg);
+
+      for (auto [wsloopArg, loopNestArg] : llvm::zip_equal(
+               wsLoopOp.getRegion().getArguments(),
+               loopNestOp.getRegion().getArguments().slice(
+                   loop.getReduceOperandsStart(), loop.getNumReduceOperands())))
+        rewriter.replaceAllUsesWith(loopNestArg, wsloopArg);
+    } else {
+      for (auto [wsloopArg, loopNestArg] :
+           llvm::zip_equal(wsLoopOp.getRegion().getArguments(),
+                           loopNestOp.getRegion().getArguments().drop_front(
+                               loopNestClauseOps.loopLowerBounds.size())))
+        rewriter.replaceAllUsesWith(loopNestArg, wsloopArg);
+    }
+
+    for (unsigned i = 0;
+         i < loop.getLocalVars().size() + loop.getReduceVars().size(); ++i)
+      loopNestOp.getRegion().eraseArgument(
+          loopNestClauseOps.loopLowerBounds.size());
+
     rewriter.setInsertionPoint(doLoop);
     fir::FirOpBuilder builder(
         rewriter,
@@ -314,7 +343,7 @@ class DoConcurrentConversion
     // Mark `unordered` loops that are not perfectly nested to be skipped from
     // the legality check of the `ConversionTarget` since we are not interested
     // in mapping them to OpenMP.
-    ompLoopNest->walk([&](fir::DoConcurrentOp doLoop) {
+    loopNestOp->walk([&](fir::DoConcurrentOp doLoop) {
       concurrentLoopsToSkip.insert(doLoop);
     });
 
@@ -370,11 +399,21 @@ class DoConcurrentConversion
       llvm::DenseMap<mlir::Value, TargetDeclareShapeCreationInfo>;
 
   mlir::omp::ParallelOp
-  genParallelOp(mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
+  genParallelOp(mlir::ConversionPatternRewriter &rewriter,
+                fir::DoConcurrentLoopOp loop,
                 looputils::InductionVariableInfos &ivInfos,
                 mlir::IRMapping &mapper) const {
-    auto parallelOp = mlir::omp::ParallelOp::create(rewriter, loc);
-    rewriter.createBlock(&parallelOp.getRegion());
+    mlir::omp::ParallelOperands parallelOps;
+
+    if (mapToDevice)
+      genPrivatizers(rewriter, mapper, loop, parallelOps);
+
+    mlir::Location loc = loop.getLoc();
+    auto parallelOp = mlir::omp::ParallelOp::create(rewriter, loc, parallelOps);
+    Fortran::common::openmp::EntryBlockArgs parallelArgs;
+    parallelArgs.priv.vars = parallelOps.privateVars;
+    Fortran::common::openmp::genEntryBlock(rewriter, parallelArgs,
+                                           parallelOp.getRegion());
     rewriter.setInsertionPoint(mlir::omp::TerminatorOp::create(rewriter, loc));
 
     genLoopNestIndVarAllocs(rewriter, ivInfos, mapper);
@@ -411,7 +450,7 @@ class DoConcurrentConversion
 
   void genLoopNestClauseOps(
       mlir::Location loc, mlir::ConversionPatternRewriter &rewriter,
-      fir::DoConcurrentLoopOp loop, mlir::IRMapping &mapper,
+      fir::DoConcurrentLoopOp loop,
       mlir::omp::LoopNestOperands &loopNestClauseOps,
       mlir::omp::TargetOperands *targetClauseOps = nullptr) const {
     assert(loopNestClauseOps.loopLowerBounds.empty() &&
@@ -440,59 +479,14 @@ class DoConcurrentConversion
     loopNestClauseOps.loopInclusive = rewriter.getUnitAttr();
   }
 
-  mlir::omp::LoopNestOp
+  std::pair<mlir::omp::LoopNestOp, mlir::omp::WsloopOp>
   genWsLoopOp(mlir::ConversionPatternRewriter &rewriter,
               fir::DoConcurrentLoopOp loop, mlir::IRMapping &mapper,
               const mlir::omp::LoopNestOperands &clauseOps,
               bool isComposite) const {
     mlir::omp::WsloopOperands wsloopClauseOps;
-
-    auto cloneFIRRegionToOMP = [&rewriter](mlir::Region &firRegion,
-                                           mlir::Region &ompRegion) {
-      if (!firRegion.empty()) {
-        rewriter.cloneRegionBefore(firRegion, ompRegion, ompRegion.begin());
-        auto firYield =
-            mlir::cast<fir::YieldOp>(ompRegion.back().getTerminator());
-        rewriter.setInsertionPoint(firYield);
-        mlir::omp::YieldOp::create(rewriter, firYield.getLoc(),
-                                   firYield.getOperands());
-        rewriter.eraseOp(firYield);
-      }
-    };
-
-    // For `local` (and `local_init`) opernads, emit corresponding `private`
-    // clauses and attach these clauses to the workshare loop.
-    if (!loop.getLocalVars().empty())
-      for (auto [op, sym, arg] : llvm::zip_equal(
-               loop.getLocalVars(),
-               loop.getLocalSymsAttr().getAsRange<mlir::SymbolRefAttr>(),
-               loop.getRegionLocalArgs())) {
-        auto localizer = moduleSymbolTable.lookup<fir::LocalitySpecifierOp>(
-            sym.getLeafReference());
-        if (localizer.getLocalitySpecifierType() ==
-            fir::LocalitySpecifierType::LocalInit)
-          TODO(localizer.getLoc(),
-               "local_init conversion is not supported yet");
-
-        mlir::OpBuilder::InsertionGuard guard(rewriter);
-        rewriter.setInsertionPointAfter(localizer);
-
-        auto privatizer = mlir::omp::PrivateClauseOp::create(
-            rewriter, localizer.getLoc(), sym.getLeafReference().str() + ".omp",
-            localizer.getTypeAttr().getValue(),
-            mlir::omp::DataSharingClauseType::Private);
-
-        cloneFIRRegionToOMP(localizer.getInitRegion(),
-                            privatizer.getInitRegion());
-        cloneFIRRegionToOMP(localizer.getDeallocRegion(),
-                            privatizer.getDeallocRegion());
-
-        moduleSymbolTable.insert(privatizer);
-
-        wsloopClauseOps.privateVars.push_back(op);
-        wsloopClauseOps.privateSyms.push_back(
-            mlir::SymbolRefAttr::get(privatizer));
-      }
+    if (!mapToDevice)
+      genPrivatizers(rewriter, mapper, loop, wsloopClauseOps);
 
     if (!loop.getReduceVars().empty()) {
       for (auto [op, byRef, sym, arg] : llvm::zip_equal(
@@ -515,15 +509,15 @@ class DoConcurrentConversion
             rewriter, firReducer.getLoc(), ompReducerName,
             firReducer.getTypeAttr().getValue());
 
-        cloneFIRRegionToOMP(firReducer.getAllocRegion(),
+        cloneFIRRegionToOMP(rewriter, firReducer.getAllocRegion(),
                             ompReducer.getAllocRegion());
-        cloneFIRRegionToOMP(firReducer.getInitializerRegion(),
+        cloneFIRRegionToOMP(rewriter, firReducer.getInitializerRegion(),
                             ompReducer.getInitializerRegion());
-        cloneFIRRegionToOMP(firReducer.getReductionRegion(),
+        cloneFIRRegionToOMP(rewriter, firReducer.getReductionRegion(),
                             ompReducer.getReductionRegion());
-        cloneFIRRegionToOMP(firReducer.getAtomicReductionRegion(),
+        cloneFIRRegionToOMP(rewriter, firReducer.getAtomicReductionRegion(),
                             ompReducer.getAtomicReductionRegion());
-        cloneFIRRegionToOMP(firReducer.getCleanupRegion(),
+        cloneFIRRegionToOMP(rewriter, firReducer.getCleanupRegion(),
                             ompReducer.getCleanupRegion());
         moduleSymbolTable.insert(ompReducer);
       }
@@ -555,21 +549,10 @@ class DoConcurrentConversion
 
     rewriter.setInsertionPointToEnd(&loopNestOp.getRegion().back());
     mlir::omp::YieldOp::create(rewriter, loop->getLoc());
+    loop->getParentOfType<mlir::ModuleOp>().print(
+        llvm::errs(), mlir::OpPrintingFlags().assumeVerified());
 
-    // `local` region arguments are transferred/cloned from the `do concurrent`
-    // loop to the loopnest op when the region is cloned above. Instead, these
-    // region arguments should be on the workshare loop's region.
-    for (auto [wsloopArg, loopNestArg] :
-         llvm::zip_equal(wsloopOp.getRegion().getArguments(),
-                         loopNestOp.getRegion().getArguments().drop_front(
-                             clauseOps.loopLowerBounds.size())))
-      rewriter.replaceAllUsesWith(loopNestArg, wsloopArg);
-
-    for (unsigned i = 0;
-         i < loop.getLocalVars().size() + loop.getReduceVars().size(); ++i)
-      loopNestOp.getRegion().eraseArgument(clauseOps.loopLowerBounds.size());
-
-    return loopNestOp;
+    return {loopNestOp, wsloopOp};
   }
 
   void genBoundsOps(fir::FirOpBuilder &builder, mlir::Value liveIn,
@@ -810,6 +793,59 @@ class DoConcurrentConversion
     return distOp;
   }
 
+  void cloneFIRRegionToOMP(mlir::ConversionPatternRewriter &rewriter,
+                           mlir::Region &firRegion,
+                           mlir::Region &ompRegion) const {
+    if (!firRegion.empty()) {
+      rewriter.cloneRegionBefore(firRegion, ompRegion, ompRegion.begin());
+      auto firYield =
+          mlir::cast<fir::YieldOp>(ompRegion.back().getTerminator());
+      rewriter.setInsertionPoint(firYield);
+      mlir::omp::YieldOp::create(rewriter, firYield.getLoc(),
+                                 firYield.getOperands());
+      rewriter.eraseOp(firYield);
+    }
+  }
+
+  void genPrivatizers(mlir::ConversionPatternRewriter &rewriter,
+                      mlir::IRMapping &mapper, fir::DoConcurrentLoopOp loop,
+                      mlir::omp::PrivateClauseOps &privateClauseOps) const {
+    // For `local` (and `local_init`) operands, emit corresponding `private`
+    // clauses and attach these clauses to the workshare loop.
+    if (!loop.getLocalVars().empty())
+      for (auto [var, sym, arg] : llvm::zip_equal(
+               loop.getLocalVars(),
+               loop.getLocalSymsAttr().getAsRange<mlir::SymbolRefAttr>(),
+               loop.getRegionLocalArgs())) {
+        auto localizer = moduleSymbolTable.lookup<fir::LocalitySpecifierOp>(
+            sym.getLeafReference());
+        if (localizer.getLocalitySpecifierType() ==
+            fir::LocalitySpecifierType::LocalInit)
+          TODO(localizer.getLoc(),
+               "local_init conversion is not supported yet");
+
+        mlir::OpBuilder::InsertionGuard guard(rewriter);
+        rewriter.setInsertionPointAfter(localizer);
+
+        auto privatizer = mlir::omp::PrivateClauseOp::create(
+            rewriter, localizer.getLoc(), sym.getLeafReference().str() + ".omp",
+            localizer.getTypeAttr().getValue(),
+            mlir::omp::DataSharingClauseType::Private);
+
+        cloneFIRRegionToOMP(rewriter, localizer.getInitRegion(),
+                            privatizer.getInitRegion());
+        cloneFIRRegionToOMP(rewriter, localizer.getDeallocRegion(),
+                            privatizer.getDeallocRegion());
+
+        moduleSymbolTable.insert(privatizer);
+
+        privateClauseOps.privateVars.push_back(mapToDevice ? mapper.lookup(var)
+                                                           : var);
+        privateClauseOps.privateSyms.push_back(
+            mlir::SymbolRefAttr::get(privatizer));
+      }
+  }
+
   bool mapToDevice;
   llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip;
   mlir::SymbolTable &moduleSymbolTable;
Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+// RUN: fir-opt --omp-do-concurrent-conversion="map-to=device" %s -o - | FileCheck %s
+
+fir.local {type = local} @_QFfooEmy_local_private_f32 : f32
+
+func.func @_QPfoo() {
+  %0 = fir.dummy_scope : !fir.dscope
+  %3 = fir.alloca f32 {bindc_name = "my_local", uniq_name = "_QFfooEmy_local"}
+  %4:2 = hlfir.declare %3 {uniq_name = "_QFfooEmy_local"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+
+  fir.do_concurrent {
+    %7 = fir.alloca i32 {bindc_name = "i"}
+    %8:2 = hlfir.declare %7 {uniq_name = "_QFfooEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+
+    fir.do_concurrent.loop (%arg0) = (%c1) to (%c10) step (%c1) local(@_QFfooEmy_local_private_f32 %4#0 -> %arg1 : !fir.ref<f32>) {
+      %9 = fir.convert %arg0 : (index) -> i32
+      fir.store %9 to %8#0 : !fir.ref<i32>
+      %10:2 = hlfir.declare %arg1 {uniq_name = "_QFfooEmy_local"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+      %cst = arith.constant 4.200000e+01 : f32
+      hlfir.assign %cst to %10#0 : f32, !fir.ref<f32>
+    }
+  }
+  return
+}
+
+// CHECK: omp.private {type = private} @[[OMP_PRIVATIZER:.*.omp]] : f32
+
+// CHECK: %[[LOCAL_DECL:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "{{.*}}my_local"}
+// CHECK: %[[LOCAL_MAP:.*]] = omp.map.info var_ptr(%[[LOCAL_DECL]]#1 : {{.*}})
+
+// CHECK: omp.target host_eval({{.*}}) map_entries({{.*}}, %[[LOCAL_MAP]] -> %[[LOCAL_MAP_ARG:.*]] : {{.*}}) {
+// CHECK:   %[[LOCAL_DEV_DECL:.*]]:2 = hlfir.declare %[[LOCAL_MAP_ARG]] {uniq_name = "_QFfooEmy_local"}
+
+// CHECK:   omp.teams {
+// CHECK:     omp.parallel private(@[[OMP_PRIVATIZER]] %[[LOCAL_DEV_DECL]]#0 -> %[[LOCAL_PRIV_ARG:.*]] : {{.*}}) {
+// CHECK:       omp.distribute {
+// CHECK:         omp.wsloop {
+// CHECK:           omp.loop_nest {{.*}} {
+// CHECK:             %[[LOCAL_LOOP_DECL:.*]]:2 = hlfir.declare %[[LOCAL_PRIV_ARG]] {uniq_name = "_QFfooEmy_local"}
+// CHECK:             hlfir.assign %{{.*}} to %[[LOCAL_LOOP_DECL]]#0
+// CHECK:             omp.yield
+// CHECK:           }
+// CHECK:         }
+// CHECK:       }
+// CHECK:     }
+// CHECK:   }
+// CHECK: }
