Skip to content

Commit b99122b

Browse files
committed
[flang][OpenMP] Basic mapping of do concurrent ... reduce to OpenMP
Now that we have changes introduced by #145837, mapping reductions from `do concurrent` to OpenMP is almost trivial. This PR adds such mapping. TODO: Add tests
1 parent 5f665c9 commit b99122b

File tree

1 file changed

+56
-27
lines changed

1 file changed

+56
-27
lines changed

flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp

Lines changed: 56 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,19 @@ class DoConcurrentConversion
312312
bool isComposite) const {
313313
mlir::omp::WsloopOperands wsloopClauseOps;
314314

315+
auto cloneFIRRegionToOMP = [&rewriter](mlir::Region &firRegion,
316+
mlir::Region &ompRegion) {
317+
if (!firRegion.empty()) {
318+
rewriter.cloneRegionBefore(firRegion, ompRegion, ompRegion.begin());
319+
auto firYield =
320+
mlir::cast<fir::YieldOp>(ompRegion.back().getTerminator());
321+
rewriter.setInsertionPoint(firYield);
322+
rewriter.create<mlir::omp::YieldOp>(firYield.getLoc(),
323+
firYield.getOperands());
324+
rewriter.eraseOp(firYield);
325+
}
326+
};
327+
315328
// For `local` (and `local_init`) opernads, emit corresponding `private`
316329
// clauses and attach these clauses to the workshare loop.
317330
if (!loop.getLocalVars().empty())
@@ -326,50 +339,65 @@ class DoConcurrentConversion
326339
TODO(localizer.getLoc(),
327340
"local_init conversion is not supported yet");
328341

329-
auto oldIP = rewriter.saveInsertionPoint();
342+
mlir::OpBuilder::InsertionGuard guard(rewriter);
330343
rewriter.setInsertionPointAfter(localizer);
344+
331345
auto privatizer = rewriter.create<mlir::omp::PrivateClauseOp>(
332346
localizer.getLoc(), sym.getLeafReference().str() + ".omp",
333347
localizer.getTypeAttr().getValue(),
334348
mlir::omp::DataSharingClauseType::Private);
335349

336-
if (!localizer.getInitRegion().empty()) {
337-
rewriter.cloneRegionBefore(localizer.getInitRegion(),
338-
privatizer.getInitRegion(),
339-
privatizer.getInitRegion().begin());
340-
auto firYield = mlir::cast<fir::YieldOp>(
341-
privatizer.getInitRegion().back().getTerminator());
342-
rewriter.setInsertionPoint(firYield);
343-
rewriter.create<mlir::omp::YieldOp>(firYield.getLoc(),
344-
firYield.getOperands());
345-
rewriter.eraseOp(firYield);
346-
}
347-
348-
if (!localizer.getDeallocRegion().empty()) {
349-
rewriter.cloneRegionBefore(localizer.getDeallocRegion(),
350-
privatizer.getDeallocRegion(),
351-
privatizer.getDeallocRegion().begin());
352-
auto firYield = mlir::cast<fir::YieldOp>(
353-
privatizer.getDeallocRegion().back().getTerminator());
354-
rewriter.setInsertionPoint(firYield);
355-
rewriter.create<mlir::omp::YieldOp>(firYield.getLoc(),
356-
firYield.getOperands());
357-
rewriter.eraseOp(firYield);
358-
}
359-
360-
rewriter.restoreInsertionPoint(oldIP);
350+
cloneFIRRegionToOMP(localizer.getInitRegion(),
351+
privatizer.getInitRegion());
352+
cloneFIRRegionToOMP(localizer.getDeallocRegion(),
353+
privatizer.getDeallocRegion());
361354

362355
wsloopClauseOps.privateVars.push_back(op);
363356
wsloopClauseOps.privateSyms.push_back(
364357
mlir::SymbolRefAttr::get(privatizer));
365358
}
366359

360+
if (!loop.getReduceVars().empty()) {
361+
for (auto [op, byRef, sym, arg] : llvm::zip_equal(
362+
loop.getReduceVars(), loop.getReduceByrefAttr().asArrayRef(),
363+
loop.getReduceSymsAttr().getAsRange<mlir::SymbolRefAttr>(),
364+
loop.getRegionReduceArgs())) {
365+
auto firReducer =
366+
mlir::SymbolTable::lookupNearestSymbolFrom<fir::DeclareReductionOp>(
367+
loop, sym);
368+
369+
mlir::OpBuilder::InsertionGuard guard(rewriter);
370+
rewriter.setInsertionPointAfter(firReducer);
371+
372+
auto ompReducer = rewriter.create<mlir::omp::DeclareReductionOp>(
373+
firReducer.getLoc(), sym.getLeafReference().str() + ".omp",
374+
firReducer.getTypeAttr().getValue());
375+
376+
cloneFIRRegionToOMP(firReducer.getAllocRegion(),
377+
ompReducer.getAllocRegion());
378+
cloneFIRRegionToOMP(firReducer.getInitializerRegion(),
379+
ompReducer.getInitializerRegion());
380+
cloneFIRRegionToOMP(firReducer.getReductionRegion(),
381+
ompReducer.getReductionRegion());
382+
cloneFIRRegionToOMP(firReducer.getAtomicReductionRegion(),
383+
ompReducer.getAtomicReductionRegion());
384+
cloneFIRRegionToOMP(firReducer.getCleanupRegion(),
385+
ompReducer.getCleanupRegion());
386+
387+
wsloopClauseOps.reductionVars.push_back(op);
388+
wsloopClauseOps.reductionByref.push_back(byRef);
389+
wsloopClauseOps.reductionSyms.push_back(
390+
mlir::SymbolRefAttr::get(ompReducer));
391+
}
392+
}
393+
367394
auto wsloopOp =
368395
rewriter.create<mlir::omp::WsloopOp>(loop.getLoc(), wsloopClauseOps);
369396
wsloopOp.setComposite(isComposite);
370397

371398
Fortran::common::openmp::EntryBlockArgs wsloopArgs;
372399
wsloopArgs.priv.vars = wsloopClauseOps.privateVars;
400+
wsloopArgs.reduction.vars = wsloopClauseOps.reductionVars;
373401
Fortran::common::openmp::genEntryBlock(rewriter, wsloopArgs,
374402
wsloopOp.getRegion());
375403

@@ -393,7 +421,8 @@ class DoConcurrentConversion
393421
clauseOps.loopLowerBounds.size())))
394422
rewriter.replaceAllUsesWith(loopNestArg, wsloopArg);
395423

396-
for (unsigned i = 0; i < loop.getLocalVars().size(); ++i)
424+
for (unsigned i = 0;
425+
i < loop.getLocalVars().size() + loop.getReduceVars().size(); ++i)
397426
loopNestOp.getRegion().eraseArgument(clauseOps.loopLowerBounds.size());
398427

399428
return loopNestOp;

0 commit comments

Comments
 (0)