@@ -140,6 +140,9 @@ void collectLoopLiveIns(fir::DoConcurrentLoopOp loop,
140140
141141 for (mlir::Value local : loop.getLocalVars ())
142142 liveIns.push_back (local);
143+
144+ for (mlir::Value reduce : loop.getReduceVars ())
145+ liveIns.push_back (reduce);
143146}
144147
145148/// Collects values that are local to a loop: "loop-local values". A loop-local
@@ -272,7 +275,7 @@ class DoConcurrentConversion
272275 targetOp =
273276 genTargetOp (doLoop.getLoc (), rewriter, mapper, loopNestLiveIns,
274277 targetClauseOps, loopNestClauseOps, liveInShapeInfoMap);
275- genTeamsOp (doLoop. getLoc (), rewriter );
278+ genTeamsOp (rewriter, loop, mapper );
276279 }
277280
278281 mlir::omp::ParallelOp parallelOp =
@@ -488,46 +491,7 @@ class DoConcurrentConversion
488491 if (!mapToDevice)
489492 genPrivatizers (rewriter, mapper, loop, wsloopClauseOps);
490493
491- if (!loop.getReduceVars ().empty ()) {
492- for (auto [op, byRef, sym, arg] : llvm::zip_equal (
493- loop.getReduceVars (), loop.getReduceByrefAttr ().asArrayRef (),
494- loop.getReduceSymsAttr ().getAsRange <mlir::SymbolRefAttr>(),
495- loop.getRegionReduceArgs ())) {
496- auto firReducer = moduleSymbolTable.lookup <fir::DeclareReductionOp>(
497- sym.getLeafReference ());
498-
499- mlir::OpBuilder::InsertionGuard guard (rewriter);
500- rewriter.setInsertionPointAfter (firReducer);
501- std::string ompReducerName = sym.getLeafReference ().str () + " .omp" ;
502-
503- auto ompReducer =
504- moduleSymbolTable.lookup <mlir::omp::DeclareReductionOp>(
505- rewriter.getStringAttr (ompReducerName));
506-
507- if (!ompReducer) {
508- ompReducer = mlir::omp::DeclareReductionOp::create (
509- rewriter, firReducer.getLoc (), ompReducerName,
510- firReducer.getTypeAttr ().getValue ());
511-
512- cloneFIRRegionToOMP (rewriter, firReducer.getAllocRegion (),
513- ompReducer.getAllocRegion ());
514- cloneFIRRegionToOMP (rewriter, firReducer.getInitializerRegion (),
515- ompReducer.getInitializerRegion ());
516- cloneFIRRegionToOMP (rewriter, firReducer.getReductionRegion (),
517- ompReducer.getReductionRegion ());
518- cloneFIRRegionToOMP (rewriter, firReducer.getAtomicReductionRegion (),
519- ompReducer.getAtomicReductionRegion ());
520- cloneFIRRegionToOMP (rewriter, firReducer.getCleanupRegion (),
521- ompReducer.getCleanupRegion ());
522- moduleSymbolTable.insert (ompReducer);
523- }
524-
525- wsloopClauseOps.reductionVars .push_back (op);
526- wsloopClauseOps.reductionByref .push_back (byRef);
527- wsloopClauseOps.reductionSyms .push_back (
528- mlir::SymbolRefAttr::get (ompReducer));
529- }
530- }
494+ genReductions (rewriter, mapper, loop, wsloopClauseOps);
531495
532496 auto wsloopOp =
533497 mlir::omp::WsloopOp::create (rewriter, loop.getLoc (), wsloopClauseOps);
@@ -549,8 +513,6 @@ class DoConcurrentConversion
549513
550514 rewriter.setInsertionPointToEnd (&loopNestOp.getRegion ().back ());
551515 mlir::omp::YieldOp::create (rewriter, loop->getLoc ());
552- loop->getParentOfType <mlir::ModuleOp>().print (
553- llvm::errs (), mlir::OpPrintingFlags ().assumeVerified ());
554516
555517 return {loopNestOp, wsloopOp};
556518 }
@@ -771,15 +733,26 @@ class DoConcurrentConversion
771733 liveInName, shape);
772734 }
773735
// genTeamsOp: creates an `omp.teams` op for a `fir.do_concurrent.loop` and
// returns it.
//
// The `-` lines below are the previous implementation, which took only a
// location and a rewriter and built the op from default-constructed
// `TeamsOperands`. The `+` lines are the replacement, which takes:
//   rewriter - the conversion rewriter used to create the new ops.
//   loop     - the loop whose reduce vars are forwarded to the teams op.
//   mapper   - IR mapping that is read by genReductions and updated here so
//              that each reduce var of `loop` maps to a teams region argument.
774- mlir::omp::TeamsOp
775- genTeamsOp (mlir::Location loc,
776- mlir::ConversionPatternRewriter &rewriter) const {
777- auto teamsOp = rewriter.create <mlir::omp::TeamsOp>(
778- loc, /* clauses=*/ mlir::omp::TeamsOperands{});
736+ mlir::omp::TeamsOp genTeamsOp (mlir::ConversionPatternRewriter &rewriter,
737+ fir::DoConcurrentLoopOp loop,
738+ mlir::IRMapping &mapper) const {
// Fill the teams clause operands with the loop's reductions before the op is
// created.
739+ mlir::omp::TeamsOperands teamsOps;
740+ genReductions (rewriter, mapper, loop, teamsOps);
741+
742+ mlir::Location loc = loop.getLoc ();
// NOTE(review): other op creations visible in this patch use the
// `Op::create(rewriter, ...)` form; consider matching that style here.
743+ auto teamsOp = rewriter.create <mlir::omp::TeamsOp>(loc, teamsOps);
// The entry block is generated from the reduction vars only, replacing the
// plain `createBlock` call of the previous implementation.
744+ Fortran::common::openmp::EntryBlockArgs teamsArgs;
745+ teamsArgs.reduction .vars = teamsOps.reductionVars ;
746+ Fortran::common::openmp::genEntryBlock (rewriter, teamsArgs,
747+ teamsOp.getRegion ());
779748
780- rewriter.createBlock (&teamsOp.getRegion ());
// Create the region terminator and position the rewriter at it.
781749 rewriter.setInsertionPoint (rewriter.create <mlir::omp::TerminatorOp>(loc));
782750
// Redirect each reduce var of the loop to the matching teams region argument.
// `zip_equal` requires both ranges to have the same length, so this presumes
// the region's arguments are exactly the reduction arguments - TODO confirm.
751+ for (auto [loopVar, teamsArg] : llvm::zip_equal (
752+ loop.getReduceVars (), teamsOp.getRegion ().getArguments ())) {
753+ mapper.map (loopVar, teamsArg);
754+ }
755+
783756 return teamsOp;
784757 }
785758
@@ -846,6 +819,52 @@ class DoConcurrentConversion
846819 }
847820 }
848821
// genReductions: appends one reduction entry (variable, by-ref flag, reducer
// symbol) to `reductionClauseOps` for every reduce var of `loop`.
//
//   rewriter           - used to create `omp.declare_reduction` ops when needed.
//   mapper             - consulted for the reduction variable when
//                        `mapToDevice` is set.
//   loop               - source of the reduce vars, by-ref flags, reduce
//                        symbols and region reduce args.
//   reductionClauseOps - clause operand struct that receives the results.
//
// For each reduction, the `fir::DeclareReductionOp` named by the symbol is
// looked up, and an `mlir::omp::DeclareReductionOp` with the derived name is
// either reused from the module symbol table or created by cloning the FIR
// reducer's regions.
822+ void genReductions (mlir::ConversionPatternRewriter &rewriter,
823+ mlir::IRMapping &mapper, fir::DoConcurrentLoopOp loop,
824+ mlir::omp::ReductionClauseOps &reductionClauseOps) const {
825+ if (!loop.getReduceVars ().empty ()) {
// NOTE(review): `arg` is bound but never used in the loop body; the region
// reduce args only take part in the `zip_equal` length requirement.
826+ for (auto [var, byRef, sym, arg] : llvm::zip_equal (
827+ loop.getReduceVars (), loop.getReduceByrefAttr ().asArrayRef (),
828+ loop.getReduceSymsAttr ().getAsRange <mlir::SymbolRefAttr>(),
829+ loop.getRegionReduceArgs ())) {
// NOTE(review): the lookup result is used below without a null check -
// presumably the symbol is guaranteed to resolve; verify.
830+ auto firReducer = moduleSymbolTable.lookup <fir::DeclareReductionOp>(
831+ sym.getLeafReference ());
832+
// The guard is scoped to one loop iteration; the insertion point is moved
// next to the FIR reducer so that a newly created OMP reducer lands there.
833+ mlir::OpBuilder::InsertionGuard guard (rewriter);
834+ rewriter.setInsertionPointAfter (firReducer);
// The OMP reducer's name is the FIR reducer's leaf name plus a fixed suffix.
835+ std::string ompReducerName = sym.getLeafReference ().str () + " .omp" ;
836+
// Reuse an OMP reducer of that name if the module already has one.
837+ auto ompReducer =
838+ moduleSymbolTable.lookup <mlir::omp::DeclareReductionOp>(
839+ rewriter.getStringAttr (ompReducerName));
840+
841+ if (!ompReducer) {
// First use of this reducer: create the OMP op with the FIR reducer's
// location and type, clone each of its five regions, and register the new op
// in the module symbol table.
842+ ompReducer = mlir::omp::DeclareReductionOp::create (
843+ rewriter, firReducer.getLoc (), ompReducerName,
844+ firReducer.getTypeAttr ().getValue ());
845+
846+ cloneFIRRegionToOMP (rewriter, firReducer.getAllocRegion (),
847+ ompReducer.getAllocRegion ());
848+ cloneFIRRegionToOMP (rewriter, firReducer.getInitializerRegion (),
849+ ompReducer.getInitializerRegion ());
850+ cloneFIRRegionToOMP (rewriter, firReducer.getReductionRegion (),
851+ ompReducer.getReductionRegion ());
852+ cloneFIRRegionToOMP (rewriter, firReducer.getAtomicReductionRegion (),
853+ ompReducer.getAtomicReductionRegion ());
854+ cloneFIRRegionToOMP (rewriter, firReducer.getCleanupRegion (),
855+ ompReducer.getCleanupRegion ());
856+ moduleSymbolTable.insert (ompReducer);
857+ }
858+
// When mapping to a device, the reduction variable is taken from `mapper`
// instead of using the loop's original value.
859+ reductionClauseOps.reductionVars .push_back (
860+ mapToDevice ? mapper.lookup (var) : var);
861+ reductionClauseOps.reductionByref .push_back (byRef);
862+ reductionClauseOps.reductionSyms .push_back (
863+ mlir::SymbolRefAttr::get (ompReducer));
864+ }
865+ }
866+ }
867+
849868 bool mapToDevice;
850869 llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip;
851870 mlir::SymbolTable &moduleSymbolTable;
0 commit comments